whispercpp 1.3.2 → 1.3.4
This diff reflects the content of publicly available package versions as released to their public registry. It is provided for informational purposes only and shows the changes between the two package versions as they appear in that registry.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +59 -27
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/build-xcframework.sh +24 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +154 -35
- data/ext/sources/examples/addon.node/index.js +10 -5
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +29 -18
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +7 -4
- data/ext/sources/examples/command/command.cpp +58 -32
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +21 -17
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +193 -35
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +10 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
- data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
- data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
- data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
- data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
- data/ext/sources/examples/talk-llama/llama-context.h +68 -32
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
- data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
- data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
- data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
- data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
- data/ext/sources/examples/talk-llama/llama-model.h +87 -9
- data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
- data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -17
- data/ext/sources/examples/talk-llama/llama.h +176 -151
- data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
- data/ext/sources/ggml/CMakeLists.txt +106 -33
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-backend.h +18 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml.h +365 -21
- data/ext/sources/ggml/src/CMakeLists.txt +98 -25
- data/ext/sources/ggml/src/ggml-alloc.c +265 -141
- data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
- data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
- data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
- data/ext/sources/ggml/src/ggml-common.h +21 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
- data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
- data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
- data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
- data/ext/sources/ggml/src/ggml-impl.h +229 -175
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +117 -24
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml.c +802 -142
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +32 -4
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +241 -215
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +3 -3
- data/ext/sources/tests/test-vad.cpp +2 -2
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +57 -2
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +75 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/{tests → test}/test_params.rb +8 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +246 -191
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -1,14 +1,16 @@
|
|
1
1
|
#include "llama-context.h"
|
2
2
|
|
3
3
|
#include "llama-impl.h"
|
4
|
+
#include "llama-batch.h"
|
4
5
|
#include "llama-io.h"
|
6
|
+
#include "llama-memory.h"
|
5
7
|
#include "llama-mmap.h"
|
6
8
|
#include "llama-model.h"
|
7
|
-
#include "llama-kv-cache.h"
|
8
9
|
|
10
|
+
#include <cinttypes>
|
9
11
|
#include <cstring>
|
12
|
+
#include <limits>
|
10
13
|
#include <stdexcept>
|
11
|
-
#include <cinttypes>
|
12
14
|
|
13
15
|
//
|
14
16
|
// llama_context
|
@@ -17,7 +19,8 @@
|
|
17
19
|
llama_context::llama_context(
|
18
20
|
const llama_model & model,
|
19
21
|
llama_context_params params) :
|
20
|
-
model(model)
|
22
|
+
model(model),
|
23
|
+
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
|
21
24
|
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
22
25
|
|
23
26
|
t_start_us = model.t_start_us;
|
@@ -26,20 +29,18 @@ llama_context::llama_context(
|
|
26
29
|
const auto & hparams = model.hparams;
|
27
30
|
|
28
31
|
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
29
|
-
if (cparams.n_seq_max >
|
30
|
-
throw std::runtime_error("n_seq_max must be <= " + std::to_string(
|
32
|
+
if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
|
33
|
+
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
31
34
|
}
|
32
35
|
|
33
36
|
cparams.n_threads = params.n_threads;
|
34
37
|
cparams.n_threads_batch = params.n_threads_batch;
|
35
|
-
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
36
|
-
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
37
|
-
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
38
|
-
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
39
|
-
cparams.defrag_thold = params.defrag_thold;
|
38
|
+
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
|
39
|
+
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
|
40
|
+
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
|
41
|
+
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
|
40
42
|
cparams.embeddings = params.embeddings;
|
41
43
|
cparams.offload_kqv = params.offload_kqv;
|
42
|
-
cparams.flash_attn = params.flash_attn;
|
43
44
|
cparams.no_perf = params.no_perf;
|
44
45
|
cparams.pooling_type = params.pooling_type;
|
45
46
|
cparams.warmup = false;
|
@@ -84,21 +85,32 @@ llama_context::llama_context(
|
|
84
85
|
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
|
85
86
|
}
|
86
87
|
|
88
|
+
cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
89
|
+
|
87
90
|
// with causal attention, the batch size is limited by the context size
|
88
91
|
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
89
92
|
|
90
93
|
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
91
94
|
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
92
95
|
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
93
|
-
// TODO: this padding is not needed for the cache-less context so we should probably move it to
|
96
|
+
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
|
94
97
|
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
95
98
|
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
96
99
|
cparams.n_batch = GGML_KQ_MASK_PAD;
|
97
100
|
}
|
98
|
-
|
99
101
|
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
100
102
|
|
101
103
|
cparams.op_offload = params.op_offload;
|
104
|
+
cparams.kv_unified = params.kv_unified;
|
105
|
+
|
106
|
+
{
|
107
|
+
const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
|
108
|
+
graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
|
109
|
+
|
110
|
+
if (graph_reuse_disable) {
|
111
|
+
LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__);
|
112
|
+
}
|
113
|
+
}
|
102
114
|
|
103
115
|
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
104
116
|
|
@@ -108,7 +120,8 @@ llama_context::llama_context(
|
|
108
120
|
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
109
121
|
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
110
122
|
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
111
|
-
LLAMA_LOG_INFO("%s: flash_attn = %
|
123
|
+
LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
|
124
|
+
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
|
112
125
|
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
113
126
|
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
114
127
|
|
@@ -168,7 +181,7 @@ llama_context::llama_context(
|
|
168
181
|
// graph outputs buffer
|
169
182
|
{
|
170
183
|
// resized during inference when a batch uses more outputs
|
171
|
-
if (
|
184
|
+
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
|
172
185
|
throw std::runtime_error("failed to reserve initial output buffer");
|
173
186
|
}
|
174
187
|
|
@@ -219,8 +232,8 @@ llama_context::llama_context(
|
|
219
232
|
|
220
233
|
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
221
234
|
|
222
|
-
|
223
|
-
|
235
|
+
gf_res_prev.reset(new llm_graph_result(max_nodes));
|
236
|
+
gf_res_reserve.reset(new llm_graph_result(max_nodes));

 // TODO: move these checks to ggml_backend_sched
 // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -257,45 +270,79 @@ llama_context::llama_context(
 }
 }

-
-
-
-
+if (!hparams.vocab_only) {
+llama_memory_context_ptr mctx;
+if (memory) {
+LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+mctx = memory->init_full();
+if (!mctx) {
+throw std::runtime_error("failed to initialize memory module");
+}
+}
+
+cross.v_embd.clear();

-
+const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

-//
-
-const auto n_outputs_save = n_outputs;
+// avoid reserving graphs with zero outputs - assume one output per sequence
+n_outputs = n_seqs;

 LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);

+// resolve automatic Flash Attention use
+if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
+if (!gf) {
+throw std::runtime_error("failed to split graph for Flash Attention check");
+}
+
+const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+bool fa_device_mismatch = false;
+for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+ggml_tensor * n = ggml_graph_node(gf, i);
+if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+continue;
+}
+ggml_backend_dev_t device_fa = ggml_backend_get_device(
+ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+// TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+const int il = std::stoi(n->name + prefix_len);
+ggml_backend_dev_t device_kv = model.dev_layer(il);
+if (device_fa != device_kv) {
+LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+"is assigned to device %s (usually due to missing support)\n",
+__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+// FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+fa_device_mismatch = true;
+break;
+}
+}
+if (fa_device_mismatch) {
+cparams.flash_attn = false;
+LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+if (ggml_is_quantized(params.type_v)) {
+throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+}
+} else {
+cparams.flash_attn = true;
+LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+}
+}
+
+// reserve worst-case graph
 int n_splits_pp = -1;
 int n_nodes_pp = -1;

 int n_splits_tg = -1;
 int n_nodes_tg = -1;

-//
-llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-kv_self->set_full();
-
-cross.v_embd.clear();
-
-// reserve pp graph first so that buffers are only allocated once
+// reserve pp (prompt processing) graph first so that buffers are only allocated once
 {
-
-
-// max number of outputs
-n_outputs = ubatch_pp.n_tokens;
-
-LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
-auto * gf = graph_init();
-graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+if (!gf) {
 throw std::runtime_error("failed to allocate compute pp buffers");
 }

@@ -303,18 +350,10 @@ llama_context::llama_context(
 n_nodes_pp = ggml_graph_n_nodes(gf);
 }

-// reserve with tg graph to get the number of splits and nodes
+// reserve with tg (token generation) graph to get the number of splits and nodes
 {
-
-
-n_outputs = ubatch_tg.n_tokens;
-
-LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
-
-auto * gf = graph_init();
-graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
-
-if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
+if (!gf) {
 throw std::runtime_error("failed to allocate compute tg buffers");
 }

@@ -324,22 +363,16 @@ llama_context::llama_context(

 // reserve again with pp graph to avoid ggml-alloc reallocations during inference
 {
-
-
-
-
-
-
-auto * gf = graph_init();
-graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
-if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+//
+// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+//
+auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+if (!gf) {
 throw std::runtime_error("failed to allocate compute pp buffers");
 }
 }

-n_outputs = n_outputs_save;
-
 for (size_t i = 0; i < backend_ptrs.size(); ++i) {
 ggml_backend_t backend = backend_ptrs[i];
 ggml_backend_buffer_type_t buft = backend_buft[i];
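
Note on the constructor changes above: they add an automatic Flash Attention mode that is resolved per device while the worst-case graphs are reserved. A minimal caller-side sketch, assuming only the llama.h declarations visible in this diff (the model pointer is assumed to come from the usual loader and is not shown):

    // rely on the new AUTO default and let llama_context decide per device
    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // already the default in this version
    llama_context * ctx = llama_init_from_model(model, cparams);

If a layer's Flash Attention tensor cannot stay on the same device as its KV data, the context falls back to flash_attn = false, as the hunk above shows.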
@@ -411,10 +444,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
 return sched.get();
 }

-ggml_context * llama_context::get_ctx_compute() const {
-return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
 return cparams.n_ctx;
 }
@@ -443,46 +472,62 @@ uint32_t llama_context::n_threads_batch() const {
 return cparams.n_threads_batch;
 }

-
-
-return kv_self;
-}
-
-const llama_kv_cache * llama_context::get_kv_self() const {
-llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-return kv_self;
+llama_memory_t llama_context::get_memory() const {
+return memory.get();
 }

-
-
-
-
-
-need_reserve = kv_self->update(*this);
+bool llama_context::memory_update(bool optimize) {
+if (!memory) {
+return false;
+}

-
-
-
+{
+const auto mctx = memory->init_update(this, optimize);
+switch (mctx->get_status()) {
+case LLAMA_MEMORY_STATUS_SUCCESS:
+{
+// noop
+} break;
+case LLAMA_MEMORY_STATUS_NO_UPDATE:
+{
+// no updates need to be performed
+return false;
+}
+case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+{
+LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+return false;
+}
+}

-//
-
-
+// reset the previous graph result to make sure that it won't be reused
+// TODO: change the mctx->apply() to return information if a graph reserve is needed
+// reset the graph result only if the memory module did reset the scheduler
+gf_res_prev->reset();

-
-
+if (!mctx->apply()) {
+LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+}
+}

-
-
+// if the memory module did any computation, we have to reserve a new worst-case graph
+{
+const auto mctx = memory->init_full();
+if (!mctx) {
+throw std::runtime_error("failed to initialize memory context");
+}

-
-
+const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

-
-
-
-LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+if (!gf) {
+LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
 }
 }
+
+return true;
 }

 enum llama_pooling_type llama_context::pooling_type() const {
@@ -490,11 +535,15 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }

 float * llama_context::get_logits() {
+output_reorder();
+
 return logits;
 }

 float * llama_context::get_logits_ith(int32_t i) {
-
+int64_t j = -1;
+
+output_reorder();

 try {
 if (logits == nullptr) {
@@ -517,7 +566,7 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 if (j >= n_outputs) {
 // This should not happen
-throw std::runtime_error(format("corrupt output buffer (j=%
+throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
 }

 return logits + j*model.vocab.n_tokens();
@@ -532,11 +581,15 @@ float * llama_context::get_logits_ith(int32_t i) {
 }

 float * llama_context::get_embeddings() {
+output_reorder();
+
 return embd;
 }

 float * llama_context::get_embeddings_ith(int32_t i) {
-
+int64_t j = -1;
+
+output_reorder();

 try {
 if (embd == nullptr) {
@@ -559,7 +612,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
 }
 if (j >= n_outputs) {
 // This should not happen
-throw std::runtime_error(format("corrupt output buffer (j=%
+throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
 }

 return embd + j*model.hparams.n_embd;
@@ -676,72 +729,119 @@ bool llama_context::apply_adapter_cvec(
 return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }

-
-if (
-LLAMA_LOG_ERROR("%s:
-
+llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+if (mctx && !mctx->apply()) {
+LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+ret = GGML_STATUS_FAILED;
+return nullptr;
 }

-
-
-llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
+auto * res = gf_res_prev.get();
+auto * gf = res->get_gf();

-
-
+// the new graph parameters
+// in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
+const auto gparams = graph_params(res, ubatch, mctx, gtype);

-
+if (!graph_reuse_disable && res->can_reuse(gparams)) {
+//LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);

-
+n_reused++;
+} else {
+res->reset();

-
-
-for (int32_t i = 0; i < n_tokens; ++i) {
-if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-return -1;
-}
+ggml_backend_sched_reset(sched.get());
+ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

-
-
-
-
+//const auto t_start_us = ggml_time_us();
+
+gf = model.build_graph(gparams);
+
+//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+if (!gf) {
+LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+ret = GGML_STATUS_FAILED;
+return nullptr;
+}
+
+if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+ret = GGML_STATUS_ALLOC_FAILED;
+return nullptr;
 }
 }

+// set the input data for the input tensors
+{
+//const auto t_start_us = ggml_time_us();
+
+res->set_inputs(&ubatch);
+
+//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+}
+
+const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
+if (status != GGML_STATUS_SUCCESS) {
+LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+ret = status;
+return nullptr;
+}
+
+ret = GGML_STATUS_SUCCESS;
+
+return res;
+}
+
+int llama_context::encode(const llama_batch & batch_inp) {
+GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
+if (batch_inp.n_tokens == 0) {
+LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+return -1;
+}
+
+const auto & hparams = model.hparams;
+
+const int64_t n_embd = hparams.n_embd;
+const int64_t n_vocab = model.vocab.n_tokens();
+
+// note: during encode, we always pass the full sequence starting from pos = 0
+if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
+LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+return -1;
+}
+
+const uint32_t n_tokens = balloc->get_n_tokens();
+
+// [TAG_NO_CACHE_PAD]
+// TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
+const llama_ubatch ubatch = balloc->split_simple(n_tokens);
+
 // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-GGML_ASSERT(cparams.n_ubatch >=
+GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");

 if (t_compute_start_us == 0) {
 t_compute_start_us = ggml_time_us();
 }

+// TODO: this clear of the buffer can easily be forgotten - need something better
 embd_seq.clear();

 n_queued_tokens += n_tokens;

-const int64_t n_embd = hparams.n_embd;
-
-llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
-
-const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
 // reserve output buffer
 if (output_reserve(n_tokens) < n_tokens) {
 LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
 return -2;
 };

-for (
+for (uint32_t i = 0; i < n_tokens; ++i) {
 output_ids[i] = i;
 }

 n_outputs = n_tokens;

-//batch_manager->prepare(ubatch);
-
-ggml_backend_sched_reset(sched.get());
-ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
 const auto causal_attn_org = cparams.causal_attn;

 // always use non-causal attention for encoder graphs
@@ -749,32 +849,34 @@ int llama_context::encode(llama_batch & inp_batch) {
 // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
 cparams.causal_attn = false;

-
-auto res =
-
-ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-res->set_inputs(&ubatch);
+ggml_status status;
+const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

 cparams.causal_attn = causal_attn_org;

-
-
-
-
-
-
-
-return -2;
-case GGML_STATUS_FAILED:
-default:
-return -3;
+if (!res) {
+switch (status) {
+case GGML_STATUS_ABORTED: return 2;
+case GGML_STATUS_ALLOC_FAILED: return -2;
+case GGML_STATUS_FAILED: return -3;
+case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
+}
 }

+auto * t_logits = res->get_logits();
 auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();

+// extract logits
+if (logits && t_logits) {
+ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+GGML_ASSERT(backend_res != nullptr);
+GGML_ASSERT(logits != nullptr);
+
+ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+}
+
 // extract embeddings
-if (t_embd) {
+if (embd && t_embd) {
 ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
 GGML_ASSERT(backend_embd != nullptr);

@@ -793,31 +895,28 @@ int llama_context::encode(llama_batch & inp_batch) {
 {
 // extract sequence embeddings
 auto & embd_seq_out = embd_seq;
-embd_seq_out.clear();

-
+for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+const int32_t seq_idx = ubatch.seq_idx[seq_id];

-for (int32_t i = 0; i < n_tokens; i++) {
-const llama_seq_id seq_id = ubatch.seq_id[i][0];
-if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-continue;
-}
 embd_seq_out[seq_id].resize(n_embd);
-ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
+ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
 }
 } break;
 case LLAMA_POOLING_TYPE_RANK:
 {
-// extract the rerank score -
+// extract the rerank score - n_cls_out floats per sequence
 auto & embd_seq_out = embd_seq;

-
-
-
-
-
-
-
+const uint32_t n_cls_out = hparams.n_cls_out;
+
+for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+embd_seq_out[seq_id].resize(n_cls_out);
+ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
 }
 } break;
 case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -827,10 +926,6 @@ int llama_context::encode(llama_batch & inp_batch) {
 }
 }

-// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-// overlap with device computation.
-ggml_backend_sched_reset(sched.get());
-
 // TODO: hacky solution
 if (model.arch == LLM_ARCH_T5 && t_embd) {
 //cross.t_embd = t_embd;
@@ -842,12 +937,16 @@ int llama_context::encode(llama_batch & inp_batch) {
 cross.v_embd.resize(cross.n_embd*cross.n_enc);
 memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));

+const auto & batch = balloc->get_batch();
+
 // remember the sequence ids used during the encoding - needed for cross attention later
 cross.seq_ids_enc.resize(n_tokens);
-for (
+for (uint32_t i = 0; i < n_tokens; i++) {
 cross.seq_ids_enc[i].clear();
-
-
+
+for (int s = 0; s < batch.n_seq_id[i]; s++) {
+const llama_seq_id seq_id = batch.seq_id[i][s];
+
 cross.seq_ids_enc[i].insert(seq_id);
 }
 }
@@ -856,55 +955,42 @@ int llama_context::encode(llama_batch & inp_batch) {
 return 0;
 }

-int llama_context::decode(llama_batch &
+int llama_context::decode(const llama_batch & batch_inp) {
+GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
 if (!memory) {
 LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
-return encode(
+return encode(batch_inp);
 }

-if (
+if (batch_inp.n_tokens == 0) {
 LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
 return -1;
 }

-if (!inp_batch.pos) {
-if (inp_batch.seq_id) {
-LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-return -1;
-}
-}
-
-llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-// temporary allocate memory for the input batch if needed
-llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
-
-const llama_batch & batch = batch_allocr.batch;
-
 const auto & vocab = model.vocab;
 const auto & hparams = model.hparams;

-const
-
-const int64_t n_tokens_all = batch.n_tokens;
-const int64_t n_embd = hparams.n_embd;
+const int64_t n_vocab = vocab.n_tokens();
+const int64_t n_embd = hparams.n_embd;

-
+// when computing embeddings, all tokens are output
+const bool output_all = cparams.embeddings;

-
+if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
+LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+return -1;
+}

-
-
-for (int64_t i = 0; i < n_tokens_all; ++i) {
-if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-return -1;
-}
+const uint32_t n_tokens_all = balloc->get_n_tokens();
+const uint32_t n_outputs_all = balloc->get_n_outputs();

-
-
-
-
+if (output_all) {
+// require that all tokens are output
+if (n_outputs_all != n_tokens_all) {
+LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
+__func__, n_outputs_all, n_tokens_all);
+return -1;
 }
 }

@@ -917,49 +1003,78 @@ int llama_context::decode(llama_batch & inp_batch) {
 }
 n_queued_tokens += n_tokens_all;

-// this
-const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+// TODO: this clear of the buffer can easily be forgotten - need something better
 embd_seq.clear();
+output_swaps.clear();

-
+bool did_optimize = false;

-//
-
-
-
+// handle any pending shifts/copies
+memory_update(false);
+
+llama_memory_context_ptr mctx;
+
+while (true) {
+mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+if (!mctx) {
+return -2;
 }
-} else if (embd_pooled) {
-n_outputs_all = n_tokens_all;
-} else {
-// keep last output only
-n_outputs_all = 1;
-}

-
+switch (mctx->get_status()) {
+case LLAMA_MEMORY_STATUS_SUCCESS:
+{
+} break;
+case LLAMA_MEMORY_STATUS_NO_UPDATE:
+{
+LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
+
+return -2;
+}
+case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+{
+if (!did_optimize) {
+did_optimize = true;
+
+if (memory_update(true)) {
+LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
+
+continue;
+}
+}
+
+LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
+
+return 1;
+}
+case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+{
+LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
+
+return -2;
+}
+}
+
+break;
+}

 // reserve output buffer
 if (output_reserve(n_outputs_all) < n_outputs_all) {
-LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
+LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
 return -2;
 };

-// handle any pending defrags/shifts
-kv_self_update();
-
 int64_t n_outputs_prev = 0;

-
-
+do {
+const auto & ubatch = mctx->get_ubatch();

-// count the outputs in this
+// count the outputs in this ubatch
 {
 int32_t n_outputs_new = 0;

 if (n_outputs_all == n_tokens_all) {
 n_outputs_new = ubatch.n_tokens;
 } else {
-GGML_ASSERT(ubatch.output);
 for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
 n_outputs_new += (int32_t) (ubatch.output[i] != 0);
 }
@@ -969,33 +1084,37 @@ int llama_context::decode(llama_batch & inp_batch) {
 n_outputs = n_outputs_new;
 }

-
-
-return 1;
-}
+ggml_status status;
+const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

-
-
+if (!res) {
+// the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
+llama_pos pos_min[LLAMA_MAX_SEQ];
+for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+pos_min[s] = std::numeric_limits<llama_pos>::max();
+}

-
-
+for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+const auto & seq_id = ubatch.seq_id[i][0];

-
+pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+}

-
+for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+continue;
+}

-
+LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);

-
-
-
-
-
-case GGML_STATUS_ALLOC_FAILED:
-
-case
-default:
-return -3;
+memory->seq_rm(s, pos_min[s], -1);
+}
+
+switch (status) {
+case GGML_STATUS_ABORTED: return 2;
+case GGML_STATUS_ALLOC_FAILED: return -2;
+case GGML_STATUS_FAILED: return -3;
+case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
 }
 }

@@ -1004,7 +1123,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 // ggml_graph_dump_dot(gf, NULL, "llama.dot");
 //}

-auto * t_logits =
+auto * t_logits = res->get_logits();
 auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;

 if (t_embd && res->get_embd_pooled()) {
@@ -1051,27 +1170,27 @@ int llama_context::decode(llama_batch & inp_batch) {
 // extract sequence embeddings (cleared before processing each batch)
 auto & embd_seq_out = embd_seq;

-for (uint32_t s = 0; s < ubatch.
-const llama_seq_id seq_id
-
-
-}
+for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
 embd_seq_out[seq_id].resize(n_embd);
-ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
+ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
 }
 } break;
 case LLAMA_POOLING_TYPE_RANK:
 {
-// extract the rerank score -
+// extract the rerank score - n_cls_out floats per sequence
 auto & embd_seq_out = embd_seq;

-
-
-
-
-
-
-
+const uint32_t n_cls_out = hparams.n_cls_out;
+
+for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+embd_seq_out[seq_id].resize(n_cls_out);
+ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
 }
 } break;
 case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1082,23 +1201,20 @@ int llama_context::decode(llama_batch & inp_batch) {
 }

 n_outputs_prev += n_outputs;
-}
-
-// finalize the batch processing
-kv_guard.commit();
+} while (mctx->next());

 // set to total number of outputs in the batch, for use in llama_get_logits_ith
 n_outputs = n_outputs_all;

 // set output mappings
-{
+if (n_outputs > 0) {
 bool sorted_output = true;

-auto & out_ids =
+auto & out_ids = balloc->get_out_ids();

-GGML_ASSERT(out_ids.size() == (size_t)
+GGML_ASSERT(out_ids.size() == (size_t) n_outputs);

-for (int64_t i = 0; i <
+for (int64_t i = 0; i < n_outputs; ++i) {
 int64_t out_id = out_ids[i];
 output_ids[out_id] = i;
 if (out_id != i) {
@@ -1109,35 +1225,29 @@ int llama_context::decode(llama_batch & inp_batch) {
 // make the outputs have the same order they had in the user-provided batch
 // note: this is mostly relevant for recurrent models atm
 if (!sorted_output) {
-const uint32_t n_vocab = model.vocab.n_tokens();
-const uint32_t n_embd = model.hparams.n_embd;
-
 GGML_ASSERT((size_t) n_outputs == out_ids.size());

 // TODO: is there something more efficient which also minimizes swaps?
 // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-for (
-
-for (
+for (uint32_t i = 0; i < n_outputs - 1; ++i) {
+uint32_t j_min = i;
+for (uint32_t j = i + 1; j < n_outputs; ++j) {
 if (out_ids[j] < out_ids[j_min]) {
 j_min = j;
 }
 }
-if (j_min == i) {
-
-if (logits_size > 0) {
-for (uint32_t k = 0; k < n_vocab; k++) {
-std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-}
-}
-if (embd_size > 0) {
-for (uint32_t k = 0; k < n_embd; k++) {
-std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-}
+if (j_min == i) {
+continue;
 }
+std::swap(out_ids[i], out_ids[j_min]);
+
+// remember the swaps and apply them lazily upon logits/embeddings access
+output_swaps.push_back({ i, j_min });
 }
+
 std::fill(output_ids.begin(), output_ids.end(), -1);
-
+
+for (uint32_t i = 0; i < n_outputs; ++i) {
 output_ids[out_ids[i]] = i;
 }
 }
@@ -1146,15 +1256,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 // wait for the computation to finish (automatically done when obtaining the model output)
 //synchronize();

-// decide if we need to defrag the kv cache
-if (cparams.defrag_thold > 0.0f) {
-kv_self->defrag_sched(cparams.defrag_thold);
-}
-
-// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-// overlap with device computation.
-ggml_backend_sched_reset(sched.get());
-
 return 0;
 }

@@ -1162,7 +1263,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 // output
 //

-
+uint32_t llama_context::output_reserve(int32_t n_outputs) {
 const auto & hparams = model.hparams;
 const auto & vocab = model.vocab;

@@ -1172,9 +1273,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
 const auto n_vocab = vocab.n_tokens();
 const auto n_embd = hparams.n_embd;

-
-bool
-bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+bool has_logits = true;
+bool has_embd = cparams.embeddings;

 // TODO: hacky enc-dec support
 if (model.arch == LLM_ARCH_T5) {
@@ -1228,53 +1328,114 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
 // set all ids as invalid (negative)
 std::fill(output_ids.begin(), output_ids.end(), -1);

-this->n_outputs
-this->n_outputs_max = n_outputs_max;
+this->n_outputs = 0;

 return n_outputs_max;
 }

+void llama_context::output_reorder() {
+const uint64_t n_vocab = model.vocab.n_tokens();
+const uint64_t n_embd = model.hparams.n_embd;
+
+for (size_t s = 0; s < output_swaps.size(); ++s) {
+const uint64_t i0 = output_swaps[s].i0;
+const uint64_t i1 = output_swaps[s].i1;
+
+if (logits_size > 0) {
+for (uint64_t k = 0; k < n_vocab; k++) {
+std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+}
+}
+
+if (embd_size > 0) {
+for (uint64_t k = 0; k < n_embd; k++) {
+std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+}
+}
+}
+
+output_swaps.clear();
+}
+
 //
 // graph
 //

-
-return std::max<
+uint32_t llama_context::graph_max_nodes() const {
+return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }

-
-
-
-/*.mem_buffer =*/ buf_compute_meta.data(),
-/*.no_alloc =*/ true,
-};
+llm_graph_result * llama_context::get_gf_res_reserve() const {
+return static_cast<llm_graph_result *>(gf_res_reserve.get());
+}

-
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
+LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+GGML_ASSERT(n_outputs >= 1);

-
-
+if (n_tokens % n_seqs != 0) {
+n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
+n_outputs = std::min(n_outputs, n_tokens);

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
+}
+
+ggml_backend_sched_reset(sched.get());
+
+// when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
+gf_res_prev->reset();
+
+// store the n_outputs as it is, and restore it afterwards
+// TODO: not sure if needed, might simplify in the future by removing this
+const auto save_n_outputs = this->n_outputs;
+
+this->n_outputs = n_outputs;
+
+llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
+llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
+
+auto * res = gf_res_reserve.get();
+
+const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
+
+res->reset();
+
+auto * gf = model.build_graph(gparams);
+
+this->n_outputs = save_n_outputs;
+
+// initialize scheduler with the specified graph
+if (split_only) {
+ggml_backend_sched_split_graph(sched.get(), gf);
+} else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+return nullptr;
+}
+
+return gf;
+}
+
+llm_graph_params llama_context::graph_params(
+llm_graph_result * res,
+const llama_ubatch & ubatch,
+const llama_memory_context_i * mctx,
+llm_graph_type gtype) const {
+return {
+/*.arch =*/ model.arch,
+/*.hparams =*/ model.hparams,
+/*.cparams =*/ cparams,
+/*.ubatch =*/ ubatch,
+/*.gtype =*/ gtype,
+/*.sched =*/ sched.get(),
+/*.backend_cpu =*/ backend_cpu,
+/*.cvec =*/ &cvec,
+/*.loras =*/ &loras,
+/*.mctx =*/ mctx,
+/*.cross =*/ &cross,
+/*.n_outputs =*/ n_outputs,
+/*.cb =*/ graph_get_cb(),
+/*.res =*/ res,
+};
 }

 ggml_status llama_context::graph_compute(
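
The output_reorder() addition above makes the decode path record selection-sort swaps of out_ids in output_swaps and replay them on the logits/embeddings buffers only when an accessor runs. A small self-contained illustration of that record-then-replay pattern (this is not the library code, just the same idea under the same assumptions):

    #include <cstdint>
    #include <utility>
    #include <vector>

    struct swap_rec { uint32_t i0, i1; };

    // selection sort over out_ids, remembering every swap that was applied
    std::vector<swap_rec> sort_and_record(std::vector<int64_t> & out_ids) {
        std::vector<swap_rec> swaps;
        for (uint32_t i = 0; i + 1 < out_ids.size(); ++i) {
            uint32_t j_min = i;
            for (uint32_t j = i + 1; j < out_ids.size(); ++j) {
                if (out_ids[j] < out_ids[j_min]) {
                    j_min = j;
                }
            }
            if (j_min != i) {
                std::swap(out_ids[i], out_ids[j_min]);
                swaps.push_back({ i, j_min });
            }
        }
        return swaps;
    }

    // replay the recorded swaps lazily over a row-major buffer (one row per output)
    void apply_swaps(float * buf, size_t row_size, const std::vector<swap_rec> & swaps) {
        for (const auto & s : swaps) {
            for (size_t k = 0; k < row_size; ++k) {
                std::swap(buf[s.i0*row_size + k], buf[s.i1*row_size + k]);
            }
        }
    }

Deferring the buffer swaps keeps decode() cheap when the caller never asks for reordered logits.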
@@ -1286,7 +1447,9 @@ ggml_status llama_context::graph_compute(
|
|
1286
1447
|
if (backend_cpu != nullptr) {
|
1287
1448
|
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
|
1288
1449
|
auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
|
1289
|
-
set_threadpool_fn
|
1450
|
+
if (set_threadpool_fn) {
|
1451
|
+
set_threadpool_fn(backend_cpu, tp);
|
1452
|
+
}
|
1290
1453
|
}
|
1291
1454
|
|
1292
1455
|
// set the number of threads for all the backends
|
@@ -1505,30 +1668,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
|
|
1505
1668
|
}
|
1506
1669
|
}
|
1507
1670
|
|
1508
|
-
size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
|
1671
|
+
size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
|
1509
1672
|
llama_io_write_dummy io;
|
1510
1673
|
try {
|
1511
|
-
return state_seq_write_data(io, seq_id);
|
1674
|
+
return state_seq_write_data(io, seq_id, flags);
|
1512
1675
|
} catch (const std::exception & err) {
|
1513
1676
|
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
1514
1677
|
return 0;
|
1515
1678
|
}
|
1516
1679
|
}
|
1517
1680
|
|
1518
|
-
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
|
1681
|
+
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
|
1519
1682
|
llama_io_write_buffer io(dst, size);
|
1520
1683
|
try {
|
1521
|
-
return state_seq_write_data(io, seq_id);
|
1684
|
+
return state_seq_write_data(io, seq_id, flags);
|
1522
1685
|
} catch (const std::exception & err) {
|
1523
1686
|
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
1524
1687
|
return 0;
|
1525
1688
|
}
|
1526
1689
|
}
|
1527
1690
|
|
1528
|
-
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
|
1691
|
+
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
|
1529
1692
|
llama_io_read_buffer io(src, size);
|
1530
1693
|
try {
|
1531
|
-
return state_seq_read_data(io, seq_id);
|
1694
|
+
return state_seq_read_data(io, seq_id, flags);
|
1532
1695
|
} catch (const std::exception & err) {
|
1533
1696
|
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
1534
1697
|
return 0;
|
@@ -1626,7 +1789,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
|
|
1626
1789
|
{
|
1627
1790
|
const size_t state_size = file.size() - file.tell();
|
1628
1791
|
llama_io_read_file io(&file);
|
1629
|
-
const size_t nread = state_seq_read_data(io, seq_id);
|
1792
|
+
const size_t nread = state_seq_read_data(io, seq_id, 0);
|
1630
1793
|
if (!nread) {
|
1631
1794
|
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
1632
1795
|
return 0;
|
@@ -1650,7 +1813,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
|
|
1650
1813
|
|
1651
1814
|
// save the context state using stream saving
|
1652
1815
|
llama_io_write_file io(&file);
|
1653
|
-
state_seq_write_data(io, seq_id);
|
1816
|
+
state_seq_write_data(io, seq_id, 0);
|
1654
1817
|
|
1655
1818
|
const size_t res = file.tell();
|
1656
1819
|
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
|
@@ -1679,14 +1842,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
1679
1842
|
|
1680
1843
|
std::vector<int32_t> w_output_pos;
|
1681
1844
|
|
1682
|
-
GGML_ASSERT(n_outputs <= n_outputs_max);
|
1683
|
-
|
1684
1845
|
w_output_pos.resize(n_outputs);
|
1685
1846
|
|
1686
1847
|
// build a more compact representation of the output ids
|
1687
1848
|
for (size_t i = 0; i < n_batch(); ++i) {
|
1688
1849
|
// map an output id to a position in the batch
|
1689
|
-
|
1850
|
+
int64_t pos = output_ids[i];
|
1690
1851
|
if (pos >= 0) {
|
1691
1852
|
GGML_ASSERT(pos < n_outputs);
|
1692
1853
|
w_output_pos[pos] = i;
|
@@ -1726,11 +1887,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
1726
1887
|
}
|
1727
1888
|
}
|
1728
1889
|
|
1729
|
-
|
1730
|
-
|
1731
|
-
|
1732
|
-
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
1733
|
-
kv_self->state_write(io);
|
1890
|
+
if (memory != nullptr) {
|
1891
|
+
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
|
1892
|
+
memory->state_write(io);
|
1734
1893
|
}
|
1735
1894
|
|
1736
1895
|
return io.n_bytes();
|
@@ -1815,35 +1974,29 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|
1815
1974
|
}
|
1816
1975
|
|
1817
1976
|
if (memory) {
|
1818
|
-
LLAMA_LOG_DEBUG("%s: - reading
|
1819
|
-
|
1820
|
-
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
1977
|
+
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
|
1821
1978
|
|
1822
|
-
|
1979
|
+
memory->state_read(io);
|
1823
1980
|
}
|
1824
1981
|
|
1825
1982
|
return io.n_bytes();
|
1826
1983
|
}
|
1827
1984
|
|
1828
|
-
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
1985
|
+
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
1829
1986
|
GGML_UNUSED(seq_id);
|
1830
1987
|
|
1831
1988
|
if (memory) {
|
1832
|
-
|
1833
|
-
|
1834
|
-
kv_self->state_write(io, seq_id);
|
1989
|
+
memory->state_write(io, seq_id, flags);
|
1835
1990
|
}
|
1836
1991
|
|
1837
1992
|
return io.n_bytes();
|
1838
1993
|
}
|
1839
1994
|
|
1840
|
-
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
1995
|
+
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
1841
1996
|
GGML_UNUSED(seq_id);
|
1842
1997
|
|
1843
1998
|
if (memory) {
|
1844
|
-
|
1845
|
-
|
1846
|
-
kv_self->state_read(io, seq_id);
|
1999
|
+
memory->state_read(io, seq_id, flags);
|
1847
2000
|
}
|
1848
2001
|
|
1849
2002
|
return io.n_bytes();
|
@@ -1862,6 +2015,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
|
|
1862
2015
|
data.t_eval_ms = 1e-3 * t_eval_us;
|
1863
2016
|
data.n_p_eval = std::max(1, n_p_eval);
|
1864
2017
|
data.n_eval = std::max(1, n_eval);
|
2018
|
+
data.n_reused = std::max(0, n_reused);
|
1865
2019
|
|
1866
2020
|
return data;
|
1867
2021
|
}
|
@@ -1870,6 +2024,22 @@ void llama_context::perf_reset() {
|
|
1870
2024
|
t_start_us = ggml_time_us();
|
1871
2025
|
t_eval_us = n_eval = 0;
|
1872
2026
|
t_p_eval_us = n_p_eval = 0;
|
2027
|
+
n_reused = 0;
|
2028
|
+
}
|
2029
|
+
|
2030
|
+
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
|
2031
|
+
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
|
2032
|
+
for (const auto & buft_size : model.memory_breakdown()) {
|
2033
|
+
ret[buft_size.first].model += buft_size.second;
|
2034
|
+
}
|
2035
|
+
for (const auto & buft_size : memory->memory_breakdown()) {
|
2036
|
+
ret[buft_size.first].context += buft_size.second;
|
2037
|
+
}
|
2038
|
+
for (const auto & backend_ptr : backends) {
|
2039
|
+
ggml_backend_t backend = backend_ptr.get();
|
2040
|
+
ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
2041
|
+
}
|
2042
|
+
return ret;
|
1873
2043
|
}
|
1874
2044
|
|
1875
2045
|
//
|
@@ -1904,7 +2074,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
|
|
1904
2074
|
opt_params.opt_period = n_batch / n_ubatch;
|
1905
2075
|
opt_params.get_opt_pars = lopt_params.get_opt_pars;
|
1906
2076
|
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
|
1907
|
-
|
2077
|
+
opt_params.optimizer = lopt_params.optimizer_type;
|
1908
2078
|
opt_ctx = ggml_opt_init(opt_params);
|
1909
2079
|
|
1910
2080
|
llama_opt_param_filter param_filter = lopt_params.param_filter;
|
@@ -1948,10 +2118,7 @@ void llama_context::opt_epoch_iter(
|
|
1948
2118
|
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
|
1949
2119
|
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
1950
2120
|
|
1951
|
-
|
1952
|
-
|
1953
|
-
kv_self->clear();
|
1954
|
-
llama_kv_cache_guard kv_guard(kv_self);
|
2121
|
+
memory->clear(true);
|
1955
2122
|
|
1956
2123
|
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
1957
2124
|
batch.n_tokens = n_batch;
|
@@ -1963,39 +2130,49 @@ void llama_context::opt_epoch_iter(
|
|
1963
2130
|
batch.logits [pos_batch] = true;
|
1964
2131
|
}
|
1965
2132
|
|
1966
|
-
|
2133
|
+
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
2134
|
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
2135
|
+
return;
|
2136
|
+
}
|
1967
2137
|
|
1968
|
-
|
2138
|
+
const uint32_t n_tokens_all = balloc->get_n_tokens();
|
1969
2139
|
|
1970
|
-
|
1971
|
-
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
2140
|
+
n_queued_tokens += n_tokens_all;
|
1972
2141
|
|
1973
2142
|
embd_seq.clear();
|
1974
2143
|
|
1975
|
-
|
2144
|
+
uint32_t n_outputs_all = n_tokens_all;
|
1976
2145
|
|
1977
|
-
|
2146
|
+
auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
|
2147
|
+
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
2148
|
+
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
2149
|
+
break;
|
2150
|
+
}
|
1978
2151
|
|
1979
2152
|
// reserve output buffer
|
1980
2153
|
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
1981
|
-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
|
2154
|
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
1982
2155
|
GGML_ABORT("TODO: handle this error");
|
1983
2156
|
};
|
1984
2157
|
|
1985
|
-
|
1986
|
-
|
2158
|
+
uint32_t pos_batch = 0;
|
2159
|
+
do {
|
2160
|
+
const auto & ubatch = mctx->get_ubatch();
|
1987
2161
|
|
1988
2162
|
n_outputs = ubatch.n_tokens;
|
1989
2163
|
|
1990
|
-
|
1991
|
-
|
1992
|
-
|
1993
|
-
|
1994
|
-
GGML_ABORT("TODO: handle this error");
|
2164
|
+
if (!mctx->apply()) {
|
2165
|
+
LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
|
2166
|
+
break;
|
1995
2167
|
}
|
1996
2168
|
|
1997
|
-
auto *
|
1998
|
-
|
2169
|
+
auto * res = gf_res_prev.get();
|
2170
|
+
|
2171
|
+
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
|
2172
|
+
|
2173
|
+
res->reset();
|
2174
|
+
|
2175
|
+
auto * gf = model.build_graph(gparams);
|
1999
2176
|
|
2000
2177
|
struct ggml_context * ctx_compute_opt;
|
2001
2178
|
{
|
@@ -2010,6 +2187,7 @@ void llama_context::opt_epoch_iter(
|
|
2010
2187
|
}
|
2011
2188
|
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
2012
2189
|
ggml_opt_alloc(opt_ctx, train);
|
2190
|
+
|
2013
2191
|
res->set_inputs(&ubatch);
|
2014
2192
|
{
|
2015
2193
|
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
@@ -2027,10 +2205,10 @@ void llama_context::opt_epoch_iter(
|
|
2027
2205
|
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
|
2028
2206
|
}
|
2029
2207
|
ggml_free(ctx_compute_opt);
|
2030
|
-
}
|
2031
|
-
}
|
2032
2208
|
|
2033
|
-
|
2209
|
+
pos_batch += ubatch.n_tokens;
|
2210
|
+
} while (mctx->next());
|
2211
|
+
}
|
2034
2212
|
}
|
2035
2213
|
|
2036
2214
|
void llama_context::opt_epoch(
|
@@ -2096,12 +2274,13 @@ llama_context_params llama_context_default_params() {
|
|
2096
2274
|
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
|
2097
2275
|
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
|
2098
2276
|
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
|
2277
|
+
/*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
|
2099
2278
|
/*.rope_freq_base =*/ 0.0f,
|
2100
2279
|
/*.rope_freq_scale =*/ 0.0f,
|
2101
2280
|
/*.yarn_ext_factor =*/ -1.0f,
|
2102
|
-
/*.yarn_attn_factor =*/ 1.0f,
|
2103
|
-
/*.yarn_beta_fast =*/
|
2104
|
-
/*.yarn_beta_slow =*/ 1.0f,
|
2281
|
+
/*.yarn_attn_factor =*/ -1.0f,
|
2282
|
+
/*.yarn_beta_fast =*/ -1.0f,
|
2283
|
+
/*.yarn_beta_slow =*/ -1.0f,
|
2105
2284
|
/*.yarn_orig_ctx =*/ 0,
|
2106
2285
|
/*.defrag_thold =*/ -1.0f,
|
2107
2286
|
/*.cb_eval =*/ nullptr,
|
@@ -2112,10 +2291,10 @@ llama_context_params llama_context_default_params() {
|
|
2112
2291
|
/*.abort_callback_data =*/ nullptr,
|
2113
2292
|
/*.embeddings =*/ false,
|
2114
2293
|
/*.offload_kqv =*/ true,
|
2115
|
-
/*.flash_attn =*/ false,
|
2116
2294
|
/*.no_perf =*/ true,
|
2117
2295
|
/*.op_offload =*/ true,
|
2118
2296
|
/*.swa_full =*/ true,
|
2297
|
+
/*.kv_unified =*/ false,
|
2119
2298
|
};
|
2120
2299
|
|
2121
2300
|
return result;
|
@@ -2139,12 +2318,30 @@ llama_context * llama_init_from_model(
|
|
2139
2318
|
return nullptr;
|
2140
2319
|
}
|
2141
2320
|
|
2142
|
-
if (params.
|
2321
|
+
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
|
2143
2322
|
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
|
2144
|
-
params.
|
2323
|
+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
2324
|
+
}
|
2325
|
+
|
2326
|
+
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
|
2327
|
+
const uint32_t blck_size = ggml_blck_size(params.type_k);
|
2328
|
+
if (model->hparams.n_embd_head_k % blck_size != 0) {
|
2329
|
+
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
2330
|
+
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
|
2331
|
+
return nullptr;
|
2332
|
+
}
|
2333
|
+
}
|
2334
|
+
|
2335
|
+
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
|
2336
|
+
const uint32_t blck_size = ggml_blck_size(params.type_v);
|
2337
|
+
if (model->hparams.n_embd_head_v % blck_size != 0) {
|
2338
|
+
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
2339
|
+
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
|
2340
|
+
return nullptr;
|
2341
|
+
}
|
2145
2342
|
}
|
2146
2343
|
|
2147
|
-
if (ggml_is_quantized(params.type_v) &&
|
2344
|
+
if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
|
2148
2345
|
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
|
2149
2346
|
return nullptr;
|
2150
2347
|
}
|
@@ -2190,14 +2387,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
|
|
2190
2387
|
return &ctx->get_model();
|
2191
2388
|
}
|
2192
2389
|
|
2193
|
-
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
2194
|
-
return ctx->get_kv_self();
|
2195
|
-
}
|
2196
|
-
|
2197
|
-
void llama_kv_self_update(llama_context * ctx) {
|
2198
|
-
ctx->kv_self_update();
|
2199
|
-
}
|
2200
|
-
|
2201
2390
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
     return ctx->pooling_type();
 }

@@ -2311,160 +2500,108 @@ int32_t llama_apply_adapter_cvec(
 }

 //
-//
+// memory
 //

-
-
-    const auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return 0;
-    }
-
-    int32_t res = 0;
-
-    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
-        const llama_pos p0 = kv->seq_pos_min(s);
-        const llama_pos p1 = kv->seq_pos_max(s);
-
-        if (p0 >= 0) {
-            res += (p1 - p0) + 1;
-        }
-    }
-
-    return res;
-}
-
-// deprecated
-// note: this is the same as above - will be removed anyway, so it's ok
-int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return 0;
-    }
-
-    int32_t res = 0;
-
-    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
-        const llama_pos p0 = kv->seq_pos_min(s);
-        const llama_pos p1 = kv->seq_pos_max(s);
-
-        if (p0 >= 0) {
-            res += (p1 - p0) + 1;
-        }
-    }
-
-    return res;
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+    return ctx->get_memory();
 }

-void
-
-    if (!kv) {
+void llama_memory_clear(llama_memory_t mem, bool data) {
+    if (!mem) {
         return;
     }

-
+    mem->clear(data);
 }

-bool
-
-
-
-
-
-    if (!kv) {
+bool llama_memory_seq_rm(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1) {
+    if (!mem) {
         return true;
     }

-    return
+    return mem->seq_rm(seq_id, p0, p1);
 }

-void
-
-
-
-
-
-
-    if (!kv) {
+void llama_memory_seq_cp(
+        llama_memory_t mem,
+        llama_seq_id seq_id_src,
+        llama_seq_id seq_id_dst,
+        llama_pos p0,
+        llama_pos p1) {
+    if (!mem) {
         return;
     }

-
+    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }

-void
-
-
+void llama_memory_seq_keep(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    if (!mem) {
         return;
     }

-
+    mem->seq_keep(seq_id);
 }

-void
-
-
-
-
-
-
-    if (!kv) {
+void llama_memory_seq_add(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        llama_pos delta) {
+    if (!mem) {
         return;
     }

-
+    mem->seq_add(seq_id, p0, p1, delta);
 }

-void
-
-
-
-
-
-
-    if (!kv) {
+void llama_memory_seq_div(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        int d) {
+    if (!mem) {
         return;
     }

-
+    mem->seq_div(seq_id, p0, p1, d);
 }

-llama_pos
-
-
+llama_pos llama_memory_seq_pos_min(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    if (!mem) {
         return -1;
     }

-    return
+    return mem->seq_pos_min(seq_id);
 }

-llama_pos
-
-
+llama_pos llama_memory_seq_pos_max(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    if (!mem) {
         return -1;
     }

-    return
+    return mem->seq_pos_max(seq_id);
 }

-
-
-    if (!kv) {
-        return;
-    }
-
-    // force defrag
-    kv->defrag_sched(-1.0f);
-}
-
-bool llama_kv_self_can_shift(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
-    if (!kv) {
+bool llama_memory_can_shift(llama_memory_t mem) {
+    if (!mem) {
         return false;
     }

-    return
+    return mem->get_can_shift();
 }

 // llama state API
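For orientation, a minimal caller-side sketch of the memory API introduced in the hunk above. This is not part of the upstream change: `ctx` is assumed to be an already-created llama_context, and treating -1 as an open-ended position bound is an assumption carried over from the older KV-cache calls.

#include "llama.h"

// Sketch: trim sequence 0 through the new memory handle (identifiers taken
// from the signatures shown in the hunk above).
static void trim_seq0(llama_context * ctx) {
    llama_memory_t mem = llama_get_memory(ctx);

    if (!llama_memory_can_shift(mem)) {
        return;
    }

    // remove positions [10, 20) from sequence 0 ...
    llama_memory_seq_rm (mem, /*seq_id =*/ 0, /*p0 =*/ 10, /*p1 =*/ 20);
    // ... and shift everything from position 20 onwards back by 10 (assumes -1 = "to the end")
    llama_memory_seq_add(mem, /*seq_id =*/ 0, /*p0 =*/ 20, /*p1 =*/ -1, /*delta =*/ -10);

    // -1 from either query below would mean the sequence holds no positions
    const llama_pos p_min = llama_memory_seq_pos_min(mem, 0);
    const llama_pos p_max = llama_memory_seq_pos_max(mem, 0);
    (void) p_min; (void) p_max;
}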
@@ -2536,19 +2673,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
 }

 size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
-    return ctx
+    return llama_state_seq_get_size_ext(ctx, seq_id, 0);
 }

 size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+    return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
+}
+
+size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+    return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
+}
+
+size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    return ctx->state_seq_get_size(seq_id, flags);
+}
+
+size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
     ctx->synchronize();

-    return ctx->state_seq_get_data(seq_id, dst, size);
+    return ctx->state_seq_get_data(seq_id, dst, size, flags);
 }

-size_t
+size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
     ctx->synchronize();

-    return ctx->state_seq_set_data(seq_id, src, size);
+    return ctx->state_seq_set_data(seq_id, src, size, flags);
 }

 size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
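The hunk above keeps the existing llama_state_seq_get_size / get_data / set_data entry points and routes them through new _ext variants that take a llama_state_seq_flags argument, with 0 preserving the previous behaviour. A minimal sketch of unchanged caller code, assuming `ctx` exists (not part of the diff):

#include "llama.h"
#include <vector>

// Sketch: copy the state of sequence 0 into sequence 1 through the public API.
static bool copy_seq_state(llama_context * ctx) {
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, /*seq_id =*/ 0));

    const size_t written = llama_state_seq_get_data(ctx, buf.data(), buf.size(), /*seq_id =*/ 0);
    const size_t read    = llama_state_seq_set_data(ctx, buf.data(), written,    /*seq_id =*/ 1);

    return written > 0 && read > 0; // 0 conventionally signals failure in this API
}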
@@ -2589,22 +2738,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }

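With the internal defrag-and-retry path removed, a return value of 1 ("no free slot found for the batch", per the warning deleted above) now reaches the caller, and recovery becomes the caller's decision. A minimal sketch of one possible policy, assuming `ctx` and `batch` are set up elsewhere; freeing old positions and retrying is only an illustrative choice, not the library's prescribed handling:

#include "llama.h"

// Sketch: caller-side handling of llama_decode's return value after this change.
static int32_t decode_with_recovery(llama_context * ctx, llama_batch batch) {
    int32_t ret = llama_decode(ctx, batch);

    if (ret == 1) {
        // e.g. drop the oldest 128 positions of sequence 0, then try once more
        llama_memory_seq_rm(llama_get_memory(ctx), /*seq_id =*/ 0, /*p0 =*/ 0, /*p1 =*/ 128);
        ret = llama_decode(ctx, batch);
    }

    return ret; // non-zero here means the batch was not decoded
}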
@@ -2638,12 +2773,149 @@ void llama_perf_context_print(const llama_context * ctx) {
     LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
     LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+    LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
 }

 void llama_perf_context_reset(llama_context * ctx) {
     ctx->perf_reset();
 }

+void llama_memory_breakdown_print(const struct llama_context * ctx) {
+    const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices;
+
+    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
+
+    std::vector<std::array<std::string, 9>> table_data;
+    table_data.reserve(devices.size());
+    const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n";
+    const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n";
+    const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n";
+
+    table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"});
+
+    constexpr size_t MiB = 1024 * 1024;
+    const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "};
+
+    // track seen buffer types to avoid double counting:
+    std::set<ggml_backend_buffer_type_t> seen_buffer_types;
+
+    // accumulative memory breakdown for each device and for host:
+    std::vector<llama_memory_breakdown_data> mb_dev(devices.size());
+    llama_memory_breakdown_data mb_host;
+
+    for (const auto & buft_mb : memory_breakdown) {
+        ggml_backend_buffer_type_t buft = buft_mb.first;
+        const llama_memory_breakdown_data & mb = buft_mb.second;
+        if (ggml_backend_buft_is_host(buft)) {
+            mb_host.model += mb.model;
+            mb_host.context += mb.context;
+            mb_host.compute += mb.compute;
+            seen_buffer_types.insert(buft);
+            continue;
+        }
+        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+        if (dev) {
+            int i_dev = -1;
+            for (size_t i = 0; i < devices.size(); i++) {
+                if (devices[i] == dev) {
+                    i_dev = i;
+                    break;
+                }
+            }
+            if (i_dev != -1) {
+                mb_dev[i_dev].model += mb.model;
+                mb_dev[i_dev].context += mb.context;
+                mb_dev[i_dev].compute += mb.compute;
+                seen_buffer_types.insert(buft);
+                continue;
+            }
+        }
+    }
+
+    // print memory breakdown for each device:
+    for (size_t i = 0; i < devices.size(); i++) {
+        ggml_backend_dev_t dev = devices[i];
+        llama_memory_breakdown_data mb = mb_dev[i];
+
+        const std::string name = ggml_backend_dev_name(dev);
+        std::string desc = ggml_backend_dev_description(dev);
+        for (const std::string & prefix : desc_prefixes_strip) {
+            if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) {
+                desc = desc.substr(prefix.length());
+            }
+        }
+
+        size_t free, total;
+        ggml_backend_dev_memory(dev, &free, &total);
+
+        const size_t self = mb.model + mb.context + mb.compute;
+        const size_t unaccounted = total - self - free;
+
+        table_data.push_back({
+            template_gpu,
+            " - " + name + " (" + desc + ")",
+            std::to_string(total / MiB),
+            std::to_string(free / MiB),
+            std::to_string(self / MiB),
+            std::to_string(mb.model / MiB),
+            std::to_string(mb.context / MiB),
+            std::to_string(mb.compute / MiB),
+            std::to_string(unaccounted / MiB)});
+    }
+
+    // print memory breakdown for host:
+    {
+        const size_t self = mb_host.model + mb_host.context + mb_host.compute;
+        table_data.push_back({
+            template_other,
+            " - Host",
+            "", // total
+            "", // free
+            std::to_string(self / MiB),
+            std::to_string(mb_host.model / MiB),
+            std::to_string(mb_host.context / MiB),
+            std::to_string(mb_host.compute / MiB),
+            ""}); // unaccounted
+    }
+
+    // print memory breakdown for all remaining buffer types:
+    for (const auto & buft_mb : memory_breakdown) {
+        ggml_backend_buffer_type_t buft = buft_mb.first;
+        const llama_memory_breakdown_data & mb = buft_mb.second;
+        if (seen_buffer_types.count(buft) == 1) {
+            continue;
+        }
+        const std::string name = ggml_backend_buft_name(buft);
+        const size_t self = mb.model + mb.context + mb.compute;
+        table_data.push_back({
+            template_other,
+            " - " + name,
+            "", // total
+            "", // free
+            std::to_string(self / MiB),
+            std::to_string(mb.model / MiB),
+            std::to_string(mb.context / MiB),
+            std::to_string(mb.compute / MiB),
+            ""}); // unaccounted
+        seen_buffer_types.insert(buft);
+    }
+
+    for (size_t j = 1; j < table_data[0].size(); j++) {
+        size_t max_len = 0;
+        for (const auto & td : table_data) {
+            max_len = std::max(max_len, td[j].length());
+        }
+        for (auto & td : table_data) {
+            td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' ');
+        }
+    }
+    for (const auto & td : table_data) {
+        LLAMA_LOG_INFO(td[0].c_str(),
+            __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(),
+            td[6].c_str(), td[7].c_str(), td[8].c_str());
+    }
+}
+
 //
 // training
 //