@novastera-oss/llamarn 0.2.6 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +141 -38
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +58 -24
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +37 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +53 -40
- package/cpp/llama.cpp/common/common.h +6 -2
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
- package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
- package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +141 -38
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
- package/cpp/llama.cpp/src/llama-arch.h +25 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
- package/cpp/llama.cpp/src/llama-batch.h +110 -57
- package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +360 -266
- package/cpp/llama.cpp/src/llama-context.h +27 -23
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
- package/cpp/llama.cpp/src/llama-graph.h +126 -58
- package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
- package/cpp/llama.cpp/src/llama-hparams.h +16 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
- package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +73 -36
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
- package/cpp/llama.cpp/src/llama-model.h +26 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
- package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +6 -2
- package/ios/include/llama.h +141 -38
- package/ios/libs/llama.xcframework/Info.plist +15 -15
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
|
@@ -2,8 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
#include "llama-batch.h"
|
|
4
4
|
#include "llama-graph.h"
|
|
5
|
-
#include "llama-kv-cache.h"
|
|
6
5
|
#include "llama-kv-cells.h"
|
|
6
|
+
#include "llama-memory.h"
|
|
7
7
|
|
|
8
8
|
#include <unordered_map>
|
|
9
9
|
#include <vector>
|
|
@@ -17,13 +17,26 @@ struct llama_context;
|
|
|
17
17
|
// llama_kv_cache_unified
|
|
18
18
|
//
|
|
19
19
|
|
|
20
|
-
class llama_kv_cache_unified : public
|
|
20
|
+
class llama_kv_cache_unified : public llama_memory_i {
|
|
21
21
|
public:
|
|
22
22
|
static uint32_t get_padding(const llama_cparams & cparams);
|
|
23
23
|
|
|
24
24
|
// this callback is used to filter out layers that should not be included in the cache
|
|
25
25
|
using layer_filter_cb = std::function<bool(int32_t il)>;
|
|
26
26
|
|
|
27
|
+
using ubatch_heads = std::vector<uint32_t>;
|
|
28
|
+
|
|
29
|
+
struct defrag_info {
|
|
30
|
+
bool empty() const {
|
|
31
|
+
return ids.empty();
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
// contains information about which cell moves where:
|
|
35
|
+
// - cell i moves to ids[i]
|
|
36
|
+
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
|
|
37
|
+
std::vector<uint32_t> ids;
|
|
38
|
+
};
|
|
39
|
+
|
|
27
40
|
llama_kv_cache_unified(
|
|
28
41
|
const llama_model & model,
|
|
29
42
|
layer_filter_cb && filter,
|
|
@@ -43,7 +56,18 @@ public:
|
|
|
43
56
|
// llama_memory_i
|
|
44
57
|
//
|
|
45
58
|
|
|
46
|
-
|
|
59
|
+
llama_memory_context_ptr init_batch(
|
|
60
|
+
llama_batch_allocr & balloc,
|
|
61
|
+
uint32_t n_ubatch,
|
|
62
|
+
bool embd_all) override;
|
|
63
|
+
|
|
64
|
+
llama_memory_context_ptr init_full() override;
|
|
65
|
+
|
|
66
|
+
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
|
67
|
+
|
|
68
|
+
bool get_can_shift() const override;
|
|
69
|
+
|
|
70
|
+
void clear(bool data) override;
|
|
47
71
|
|
|
48
72
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
|
49
73
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
|
@@ -54,24 +78,6 @@ public:
|
|
|
54
78
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
|
55
79
|
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
|
56
80
|
|
|
57
|
-
//
|
|
58
|
-
// llama_kv_cache
|
|
59
|
-
//
|
|
60
|
-
|
|
61
|
-
llama_memory_state_ptr init_batch(
|
|
62
|
-
const llama_batch & batch,
|
|
63
|
-
uint32_t n_ubatch,
|
|
64
|
-
bool embd_pooled,
|
|
65
|
-
bool logits_all) override;
|
|
66
|
-
|
|
67
|
-
llama_memory_state_ptr init_full() override;
|
|
68
|
-
|
|
69
|
-
bool update(llama_context & lctx) override;
|
|
70
|
-
|
|
71
|
-
void defrag_sched(float thold) override;
|
|
72
|
-
|
|
73
|
-
bool get_can_shift() const override;
|
|
74
|
-
|
|
75
81
|
// state write/load
|
|
76
82
|
|
|
77
83
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
@@ -83,6 +89,8 @@ public:
|
|
|
83
89
|
|
|
84
90
|
uint32_t get_size() const;
|
|
85
91
|
|
|
92
|
+
bool get_has_shift() const;
|
|
93
|
+
|
|
86
94
|
//
|
|
87
95
|
// graph_build API
|
|
88
96
|
//
|
|
@@ -103,7 +111,9 @@ public:
|
|
|
103
111
|
|
|
104
112
|
// find places for the provided ubatches in the cache, returns the head locations
|
|
105
113
|
// return empty vector on failure
|
|
106
|
-
|
|
114
|
+
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
|
115
|
+
|
|
116
|
+
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
|
107
117
|
|
|
108
118
|
// return the cell position where we can insert the ubatch
|
|
109
119
|
// return -1 on failure to find a contiguous slot of kv cells
|
|
@@ -133,8 +143,7 @@ private:
|
|
|
133
143
|
ggml_tensor * v;
|
|
134
144
|
};
|
|
135
145
|
|
|
136
|
-
bool
|
|
137
|
-
bool v_trans = true; // the value tensor is transposed
|
|
146
|
+
bool v_trans = true; // the value tensor is transposed
|
|
138
147
|
|
|
139
148
|
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
|
140
149
|
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
|
@@ -148,6 +157,8 @@ private:
|
|
|
148
157
|
// SWA
|
|
149
158
|
const uint32_t n_swa = 0;
|
|
150
159
|
|
|
160
|
+
int debug = 0;
|
|
161
|
+
|
|
151
162
|
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
|
152
163
|
|
|
153
164
|
std::vector<ggml_context_ptr> ctxs;
|
|
@@ -160,13 +171,8 @@ private:
|
|
|
160
171
|
// model layer id -> KV cache layer id
|
|
161
172
|
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
|
162
173
|
|
|
163
|
-
//
|
|
164
|
-
|
|
165
|
-
std::vector<uint32_t> ids;
|
|
166
|
-
} defrag_info;
|
|
167
|
-
|
|
168
|
-
// return true if cells have been moved
|
|
169
|
-
bool defrag_prepare(int32_t n_max_nodes);
|
|
174
|
+
// return non-empty vector if cells have been moved
|
|
175
|
+
defrag_info defrag_prepare(int32_t n_max_nodes) const;
|
|
170
176
|
|
|
171
177
|
size_t total_size() const;
|
|
172
178
|
|
|
@@ -192,7 +198,8 @@ private:
|
|
|
192
198
|
llm_graph_result_ptr build_graph_defrag(
|
|
193
199
|
const llama_cparams & cparams,
|
|
194
200
|
ggml_context * ctx,
|
|
195
|
-
ggml_cgraph * gf
|
|
201
|
+
ggml_cgraph * gf,
|
|
202
|
+
const defrag_info & dinfo) const;
|
|
196
203
|
|
|
197
204
|
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
|
198
205
|
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
|
@@ -201,40 +208,46 @@ private:
|
|
|
201
208
|
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
|
202
209
|
};
|
|
203
210
|
|
|
204
|
-
class
|
|
211
|
+
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
|
205
212
|
public:
|
|
213
|
+
// some shorthands
|
|
214
|
+
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
|
215
|
+
using defrag_info = llama_kv_cache_unified::defrag_info;
|
|
216
|
+
|
|
206
217
|
// used for errors
|
|
207
|
-
|
|
218
|
+
llama_kv_cache_unified_context(llama_memory_status status);
|
|
208
219
|
|
|
209
|
-
// used to create a full-cache
|
|
210
|
-
|
|
211
|
-
llama_memory_status status,
|
|
220
|
+
// used to create a full-cache context
|
|
221
|
+
llama_kv_cache_unified_context(
|
|
212
222
|
llama_kv_cache_unified * kv);
|
|
213
223
|
|
|
214
|
-
// used to create
|
|
215
|
-
|
|
216
|
-
|
|
224
|
+
// used to create an update context
|
|
225
|
+
llama_kv_cache_unified_context(
|
|
226
|
+
llama_kv_cache_unified * kv,
|
|
227
|
+
llama_context * lctx,
|
|
228
|
+
bool do_shift,
|
|
229
|
+
defrag_info dinfo);
|
|
230
|
+
|
|
231
|
+
// used to create a batch procesing context from a batch
|
|
232
|
+
llama_kv_cache_unified_context(
|
|
217
233
|
llama_kv_cache_unified * kv,
|
|
218
|
-
|
|
219
|
-
std::vector<uint32_t> heads,
|
|
234
|
+
ubatch_heads heads,
|
|
220
235
|
std::vector<llama_ubatch> ubatches);
|
|
221
236
|
|
|
222
|
-
virtual ~
|
|
237
|
+
virtual ~llama_kv_cache_unified_context();
|
|
223
238
|
|
|
224
239
|
//
|
|
225
|
-
//
|
|
240
|
+
// llama_memory_context_i
|
|
226
241
|
//
|
|
227
242
|
|
|
228
243
|
bool next() override;
|
|
229
244
|
bool apply() override;
|
|
230
245
|
|
|
231
|
-
std::vector<int64_t> & out_ids() override;
|
|
232
|
-
|
|
233
246
|
llama_memory_status get_status() const override;
|
|
234
247
|
const llama_ubatch & get_ubatch() const override;
|
|
235
248
|
|
|
236
249
|
//
|
|
237
|
-
//
|
|
250
|
+
// llama_kv_cache_unified_context specific API
|
|
238
251
|
//
|
|
239
252
|
|
|
240
253
|
uint32_t get_n_kv() const;
|
|
@@ -253,16 +266,28 @@ public:
|
|
|
253
266
|
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
|
254
267
|
|
|
255
268
|
private:
|
|
256
|
-
|
|
269
|
+
llama_memory_status status;
|
|
257
270
|
|
|
258
271
|
llama_kv_cache_unified * kv;
|
|
272
|
+
llama_context * lctx;
|
|
273
|
+
|
|
274
|
+
//
|
|
275
|
+
// update context
|
|
276
|
+
//
|
|
277
|
+
|
|
278
|
+
bool do_shift = false;
|
|
259
279
|
|
|
260
|
-
|
|
280
|
+
defrag_info dinfo;
|
|
281
|
+
|
|
282
|
+
//
|
|
283
|
+
// batch processing context
|
|
284
|
+
//
|
|
261
285
|
|
|
262
286
|
// the index of the next ubatch to process
|
|
263
287
|
size_t i_next = 0;
|
|
264
288
|
|
|
265
|
-
|
|
289
|
+
ubatch_heads heads;
|
|
290
|
+
|
|
266
291
|
std::vector<llama_ubatch> ubatches;
|
|
267
292
|
|
|
268
293
|
//
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
#include <cassert>
|
|
8
8
|
#include <vector>
|
|
9
9
|
#include <set>
|
|
10
|
+
#include <map>
|
|
10
11
|
|
|
11
12
|
// meta information about KV cells that can be part of multiple sequences at the same time
|
|
12
13
|
// TODO: add unit tests
|
|
@@ -23,7 +24,7 @@ public:
|
|
|
23
24
|
|
|
24
25
|
used.clear();
|
|
25
26
|
|
|
26
|
-
for (uint32_t s = 0; s <
|
|
27
|
+
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
27
28
|
seq_pos[s].clear();
|
|
28
29
|
}
|
|
29
30
|
}
|
|
@@ -80,6 +81,9 @@ public:
|
|
|
80
81
|
assert(isrc < pos.size());
|
|
81
82
|
assert(idst < pos.size());
|
|
82
83
|
|
|
84
|
+
assert(pos[idst] == -1);
|
|
85
|
+
assert(pos[isrc] != -1);
|
|
86
|
+
|
|
83
87
|
pos [idst] = pos [isrc];
|
|
84
88
|
shift[idst] = shift[isrc];
|
|
85
89
|
seq [idst] = seq [isrc];
|
|
@@ -144,9 +148,10 @@ public:
|
|
|
144
148
|
assert(pos[i] != -1);
|
|
145
149
|
|
|
146
150
|
seq_pos_rm(i);
|
|
151
|
+
seq[i].reset();
|
|
147
152
|
|
|
148
153
|
pos[i] = -1;
|
|
149
|
-
|
|
154
|
+
shift[i] = 0;
|
|
150
155
|
|
|
151
156
|
used.erase(i);
|
|
152
157
|
}
|
|
@@ -160,10 +165,11 @@ public:
|
|
|
160
165
|
assert(seq_id >= 0);
|
|
161
166
|
|
|
162
167
|
seq[i].reset(seq_id);
|
|
163
|
-
|
|
168
|
+
seq_pos_dec(seq_id, pos[i]);
|
|
164
169
|
|
|
165
170
|
if (seq[i].none()) {
|
|
166
171
|
pos[i] = -1;
|
|
172
|
+
shift[i] = 0;
|
|
167
173
|
|
|
168
174
|
used.erase(i);
|
|
169
175
|
|
|
@@ -182,7 +188,7 @@ public:
|
|
|
182
188
|
seq[i].reset();
|
|
183
189
|
|
|
184
190
|
seq[i].set(seq_id);
|
|
185
|
-
|
|
191
|
+
seq_pos_inc(seq_id, pos[i]);
|
|
186
192
|
|
|
187
193
|
return false;
|
|
188
194
|
}
|
|
@@ -192,6 +198,7 @@ public:
|
|
|
192
198
|
seq[i].reset();
|
|
193
199
|
|
|
194
200
|
pos[i] = -1;
|
|
201
|
+
shift[i] = 0;
|
|
195
202
|
|
|
196
203
|
used.erase(i);
|
|
197
204
|
|
|
@@ -226,7 +233,7 @@ public:
|
|
|
226
233
|
assert(!seq[i].test(seq_id));
|
|
227
234
|
|
|
228
235
|
seq[i].set(seq_id);
|
|
229
|
-
|
|
236
|
+
seq_pos_inc(seq_id, pos[i]);
|
|
230
237
|
}
|
|
231
238
|
|
|
232
239
|
// return the sequence id of this cell
|
|
@@ -234,7 +241,7 @@ public:
|
|
|
234
241
|
llama_seq_id seq_get(uint32_t i) const {
|
|
235
242
|
assert(seq[i].count() == 1);
|
|
236
243
|
|
|
237
|
-
for (int s = 0; s <
|
|
244
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
238
245
|
if (seq[i].test(s)) {
|
|
239
246
|
return s;
|
|
240
247
|
}
|
|
@@ -247,26 +254,30 @@ public:
|
|
|
247
254
|
// return -1 if the sequence is not present
|
|
248
255
|
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
|
249
256
|
assert(seq_id >= 0);
|
|
250
|
-
assert(seq_id <
|
|
257
|
+
assert(seq_id < LLAMA_MAX_SEQ);
|
|
251
258
|
|
|
252
259
|
if (seq_pos[seq_id].empty()) {
|
|
253
260
|
return -1;
|
|
254
261
|
}
|
|
255
262
|
|
|
256
|
-
|
|
263
|
+
assert(seq_pos[seq_id].begin()->second > 0);
|
|
264
|
+
|
|
265
|
+
return seq_pos[seq_id].begin()->first;
|
|
257
266
|
}
|
|
258
267
|
|
|
259
268
|
// the maximum position of sequence seq_id currently present in any of the cells
|
|
260
269
|
// return -1 if the sequence is not present
|
|
261
270
|
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
|
262
271
|
assert(seq_id >= 0);
|
|
263
|
-
assert(seq_id <
|
|
272
|
+
assert(seq_id < LLAMA_MAX_SEQ);
|
|
264
273
|
|
|
265
274
|
if (seq_pos[seq_id].empty()) {
|
|
266
275
|
return -1;
|
|
267
276
|
}
|
|
268
277
|
|
|
269
|
-
|
|
278
|
+
assert(seq_pos[seq_id].rbegin()->second > 0);
|
|
279
|
+
|
|
280
|
+
return seq_pos[seq_id].rbegin()->first;
|
|
270
281
|
}
|
|
271
282
|
|
|
272
283
|
// note: call only if the cell is not empty
|
|
@@ -317,21 +328,20 @@ public:
|
|
|
317
328
|
pos[i] += d;
|
|
318
329
|
shift[i] += d;
|
|
319
330
|
|
|
320
|
-
seq_pos_add(i);
|
|
321
|
-
|
|
322
331
|
has_shift = true;
|
|
323
332
|
|
|
324
333
|
if (pos[i] < 0) {
|
|
325
|
-
seq_pos_rm(i);
|
|
326
|
-
|
|
327
334
|
seq[i].reset();
|
|
328
335
|
pos[i] = -1;
|
|
336
|
+
shift[i] = 0;
|
|
329
337
|
|
|
330
338
|
used.erase(i);
|
|
331
339
|
|
|
332
340
|
return true;
|
|
333
341
|
}
|
|
334
342
|
|
|
343
|
+
seq_pos_add(i);
|
|
344
|
+
|
|
335
345
|
return false;
|
|
336
346
|
}
|
|
337
347
|
|
|
@@ -379,31 +389,50 @@ private:
|
|
|
379
389
|
//
|
|
380
390
|
std::vector<llama_pos> shift;
|
|
381
391
|
|
|
382
|
-
using
|
|
392
|
+
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
|
|
383
393
|
|
|
384
394
|
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
|
385
|
-
std::vector<
|
|
395
|
+
std::vector<seq_set_t> seq;
|
|
386
396
|
|
|
387
|
-
// the set seq_pos[s] tells us
|
|
397
|
+
// the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
|
|
398
|
+
// if the position p is not present, seq_pos[s][p] is not set
|
|
388
399
|
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
|
389
|
-
|
|
400
|
+
//
|
|
401
|
+
// note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
|
|
402
|
+
// - during performing a cache reuse via (rm + add)
|
|
403
|
+
// - some vision models have input embeddings with repeating positions
|
|
404
|
+
//
|
|
405
|
+
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
|
|
390
406
|
|
|
391
407
|
// helper functions for updating `seq_pos`, once cell at a time:
|
|
392
408
|
|
|
409
|
+
// Drops one occurrence of position p from the counts of sequence s.
// The map entry is erased once its count reaches zero, so seq_pos[s] only ever holds
// positions that are actually present. The position must already be tracked (asserted).
void seq_pos_dec(llama_seq_id s, llama_pos p) {
    auto & counts = seq_pos[s];
    auto entry = counts.find(p);
    assert(entry != counts.end());

    entry->second -= 1;
    if (entry->second == 0) {
        counts.erase(entry);
    }
}
|
|
417
|
+
|
|
418
|
+
// Records one more occurrence of position p for sequence s.
// std::map::operator[] value-initializes a missing count to 0 before it is incremented.
void seq_pos_inc(llama_seq_id s, llama_pos p) {
    auto & count = seq_pos[s][p];
    count += 1;
}
|
|
421
|
+
|
|
393
422
|
// remove cell i
|
|
394
423
|
void seq_pos_rm(uint32_t i) {
|
|
395
|
-
for (int s = 0; s <
|
|
424
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
396
425
|
if (seq[i].test(s)) {
|
|
397
|
-
|
|
426
|
+
seq_pos_dec(s, pos[i]);
|
|
398
427
|
}
|
|
399
428
|
}
|
|
400
429
|
}
|
|
401
430
|
|
|
402
431
|
// add cell i
|
|
403
432
|
void seq_pos_add(uint32_t i) {
|
|
404
|
-
for (int s = 0; s <
|
|
433
|
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
405
434
|
if (seq[i].test(s)) {
|
|
406
|
-
|
|
435
|
+
seq_pos_inc(s, pos[i]);
|
|
407
436
|
}
|
|
408
437
|
}
|
|
409
438
|
}
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
#include "llama-memory-hybrid.h"
|
|
2
|
+
|
|
3
|
+
#include "llama-impl.h"
|
|
4
|
+
#include "llama-model.h"
|
|
5
|
+
#include "llama-context.h"
|
|
6
|
+
|
|
7
|
+
//
|
|
8
|
+
// llama_memory_hybrid
|
|
9
|
+
//
|
|
10
|
+
|
|
11
|
+
llama_memory_hybrid::llama_memory_hybrid(
|
|
12
|
+
const llama_model & model,
|
|
13
|
+
/* attn */
|
|
14
|
+
ggml_type type_k,
|
|
15
|
+
ggml_type type_v,
|
|
16
|
+
bool v_trans,
|
|
17
|
+
uint32_t kv_size,
|
|
18
|
+
uint32_t n_pad,
|
|
19
|
+
uint32_t n_swa,
|
|
20
|
+
llama_swa_type swa_type,
|
|
21
|
+
/* recurrent */
|
|
22
|
+
ggml_type type_r,
|
|
23
|
+
ggml_type type_s,
|
|
24
|
+
uint32_t rs_size,
|
|
25
|
+
/* common */
|
|
26
|
+
uint32_t n_seq_max,
|
|
27
|
+
bool offload,
|
|
28
|
+
/* layer filters */
|
|
29
|
+
layer_filter_cb && filter_attn,
|
|
30
|
+
layer_filter_cb && filter_recr) :
|
|
31
|
+
hparams(model.hparams),
|
|
32
|
+
mem_attn(new llama_kv_cache_unified(
|
|
33
|
+
model,
|
|
34
|
+
filter_attn == nullptr ?
|
|
35
|
+
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
|
36
|
+
: filter_attn,
|
|
37
|
+
type_k,
|
|
38
|
+
type_v,
|
|
39
|
+
v_trans,
|
|
40
|
+
offload,
|
|
41
|
+
kv_size,
|
|
42
|
+
n_seq_max,
|
|
43
|
+
n_pad,
|
|
44
|
+
n_swa,
|
|
45
|
+
swa_type
|
|
46
|
+
)),
|
|
47
|
+
mem_recr(new llama_memory_recurrent(
|
|
48
|
+
model,
|
|
49
|
+
filter_recr == nullptr ?
|
|
50
|
+
[&](int32_t il) { return hparams.is_recurrent(il); }
|
|
51
|
+
: filter_recr,
|
|
52
|
+
type_r,
|
|
53
|
+
type_s,
|
|
54
|
+
offload,
|
|
55
|
+
rs_size,
|
|
56
|
+
n_seq_max
|
|
57
|
+
)) {}
|
|
58
|
+
|
|
59
|
+
// Splits the incoming batch into ubatches and reserves room for them in both sub-memories.
//
// Splitting follows the recurrent pattern: split_seq() when every token is an output
// (embd_all), split_equal() otherwise. The recurrent memory is prepared first and the
// attention cache second. On success, the returned context owns the ubatches and the
// attention head indices.
//
// Returns a context with LLAMA_MEMORY_STATUS_FAILED_PREPARE if either preparation fails.
//
// note: the former `do { ... } while(false)` wrapper was removed. No path ever broke out of
// it because every branch inside returns, so its trailing failure return was unreachable.
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    balloc.split_reset();

    // follow the recurrent pattern for creating the ubatch splits
    std::vector<llama_ubatch> ubatches;

    while (true) {
        llama_ubatch ubatch;

        if (embd_all) {
            // if all tokens are output, split by sequence
            ubatch = balloc.split_seq(n_ubatch);
        } else {
            ubatch = balloc.split_equal(n_ubatch);
        }

        if (ubatch.n_tokens == 0) {
            break;
        }

        ubatches.push_back(std::move(ubatch)); // NOLINT
    }

    // prepare the recurrent batches first
    if (!mem_recr->prepare(ubatches)) {
        // TODO: will the recurrent cache be in an undefined context at this point?
        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
        return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    // prepare the attention cache
    auto heads_attn = mem_attn->prepare(ubatches);
    if (heads_attn.empty()) {
        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
        return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    return std::make_unique<llama_memory_hybrid_context>(
        this, std::move(heads_attn), std::move(ubatches));
}
|
|
103
|
+
|
|
104
|
+
// Creates a context spanning the whole memory. The hybrid context delegates to the
// init_full() of each sub-memory.
llama_memory_context_ptr llama_memory_hybrid::init_full() {
    llama_memory_context_ptr full_ctx = std::make_unique<llama_memory_hybrid_context>(this);
    return full_ctx;
}
|
|
107
|
+
|
|
108
|
+
// Creates an update context. `lctx` and `optimize` are forwarded to the init_update()
// of each sub-memory by the hybrid context.
llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
    llama_memory_context_ptr update_ctx = std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
    return update_ctx;
}
|
|
111
|
+
|
|
112
|
+
bool llama_memory_hybrid::get_can_shift() const {
|
|
113
|
+
// Shifting is trivially supported for recurrent
|
|
114
|
+
return mem_attn->get_can_shift();
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
// Clears both sub-memories: the attention cache first, then the recurrent memory.
// `data` is forwarded unchanged to each.
void llama_memory_hybrid::clear(bool data) {
    auto & attn = *mem_attn;
    auto & recr = *mem_recr;

    attn.clear(data);
    recr.clear(data);
}
|
|
121
|
+
|
|
122
|
+
// Removes the p0..p1 position range of seq_id from both sub-memories.
// Returns false, with nothing removed, if the recurrent memory refuses the removal.
bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // The recurrent memory goes first because its removal may fail; a failed removal
    // leaves it unmodified, so bailing out here keeps both sub-memories consistent.
    const bool recr_removed = mem_recr->seq_rm(seq_id, p0, p1);
    if (!recr_removed) {
        return false;
    }

    // NOTE(review): if the attention removal could fail, the recurrent memory would already
    // be modified at this point - presumably the KV cache removal does not fail; confirm.
    return mem_attn->seq_rm(seq_id, p0, p1);
}
|
|
130
|
+
|
|
131
|
+
// Copies the p0..p1 range of seq_id_src into seq_id_dst in both sub-memories
// (attention cache first, then recurrent memory).
void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    auto & attn = *mem_attn;
    auto & recr = *mem_recr;

    attn.seq_cp(seq_id_src, seq_id_dst, p0, p1);
    recr.seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
|
|
135
|
+
|
|
136
|
+
// Forwards seq_keep(seq_id) to both sub-memories (attention cache first, then recurrent memory).
void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
    auto & attn = *mem_attn;
    auto & recr = *mem_recr;

    attn.seq_keep(seq_id);
    recr.seq_keep(seq_id);
}
|
|
140
|
+
|
|
141
|
+
// Shifts the positions of seq_id in the p0..p1 range by `shift` in both sub-memories
// (attention cache first, then recurrent memory).
void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    auto & attn = *mem_attn;
    auto & recr = *mem_recr;

    attn.seq_add(seq_id, p0, p1, shift);
    recr.seq_add(seq_id, p0, p1, shift);
}
|
|
145
|
+
|
|
146
|
+
// Divides the positions of seq_id in the p0..p1 range by `d` in both sub-memories
// (attention cache first, then recurrent memory).
void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    auto & attn = *mem_attn;
    auto & recr = *mem_recr;

    attn.seq_div(seq_id, p0, p1, d);
    recr.seq_div(seq_id, p0, p1, d);
}
|
|
150
|
+
|
|
151
|
+
// Minimum position of seq_id in the hybrid memory: the larger of the two sub-memories'
// minimums.
// NOTE(review): a sub-memory may report -1 for an empty sequence; max() then picks the
// other memory's value - confirm this is the intended semantics.
llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
    const llama_pos min_attn = mem_attn->seq_pos_min(seq_id);
    const llama_pos min_recr = mem_recr->seq_pos_min(seq_id);

    return min_attn > min_recr ? min_attn : min_recr;
}
|
|
155
|
+
|
|
156
|
+
// Maximum position of seq_id in the hybrid memory: the smaller of the two sub-memories'
// maximums.
llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    const llama_pos max_attn = mem_attn->seq_pos_max(seq_id);
    const llama_pos max_recr = mem_recr->seq_pos_max(seq_id);

    return max_attn < max_recr ? max_attn : max_recr;
}
|
|
160
|
+
|
|
161
|
+
// Serializes the attention cache and then the recurrent memory into `io`.
// state_read() consumes the data in the same order, so the two must stay in sync.
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    const auto & attn = *mem_attn;
    const auto & recr = *mem_recr;

    attn.state_write(io, seq_id);
    recr.state_write(io, seq_id);
}
|
|
165
|
+
|
|
166
|
+
// Restores the attention cache and then the recurrent memory from `io`, in the same
// order state_write() produced them.
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    auto & attn = *mem_attn;
    auto & recr = *mem_recr;

    attn.state_read(io, seq_id);
    recr.state_read(io, seq_id);
}
|
|
170
|
+
|
|
171
|
+
// Non-owning pointer to the attention KV cache; ownership stays with this object.
llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
    llama_kv_cache_unified * attn = mem_attn.get();
    return attn;
}
|
|
174
|
+
|
|
175
|
+
// Non-owning pointer to the recurrent memory; ownership stays with this object.
llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
    llama_memory_recurrent * recr = mem_recr.get();
    return recr;
}
|
|
178
|
+
|
|
179
|
+
// Status-only context with no sub-contexts, e.g. the LLAMA_MEMORY_STATUS_FAILED_PREPARE
// result returned by llama_memory_hybrid::init_batch().
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status)
    : status(status) {
}
|
|
180
|
+
|
|
181
|
+
// Full-memory context: each sub-memory supplies its own init_full() context, and the
// overall status combines the two sub-statuses.
// NOTE(review): `status` reads ctx_attn/ctx_recr, which relies on both being declared
// before `status` in the class - confirm against the header.
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem)
    : ctx_attn(mem->get_mem_attn()->init_full())
    , ctx_recr(mem->get_mem_recr()->init_full())
    , status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
|
|
186
|
+
|
|
187
|
+
// Update context: `lctx` and `optimize` are forwarded to the init_update() of each
// sub-memory, and the overall status combines the two sub-statuses.
llama_memory_hybrid_context::llama_memory_hybrid_context(
        llama_memory_hybrid * mem,
        llama_context * lctx,
        bool optimize)
    : ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize))
    , ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize))
    , status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
|
|
195
|
+
|
|
196
|
+
// Batch context: takes ownership of the prepared ubatches and builds one sub-context per
// sub-memory on top of them. The attention sub-context also receives the head indices
// produced by the KV cache's prepare().
// NOTE(review): the initializers rely on `ubatches` being declared before ctx_attn/ctx_recr
// (they read this->ubatches), and on both of those being declared before `status`.
llama_memory_hybrid_context::llama_memory_hybrid_context(
        llama_memory_hybrid * mem,
        std::vector<uint32_t> heads_attn,
        std::vector<llama_ubatch> ubatches)
    // the unqualified name below is the parameter; `this->ubatches` is the member
    : ubatches(std::move(ubatches))
    // note: here we copy the ubatches. not sure if this is ideal
    , ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches))
    , ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches))
    , status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
|
|
206
|
+
|
|
207
|
+
bool llama_memory_hybrid_context::next() {
|
|
208
|
+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
209
|
+
|
|
210
|
+
ctx_attn->next();
|
|
211
|
+
ctx_recr->next();
|
|
212
|
+
|
|
213
|
+
if (++i_next >= ubatches.size()) {
|
|
214
|
+
return false;
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
return true;
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
// Applies the current ubatch to both sub-memories.
// Both sub-contexts are always applied, attention first, even if the first one reports
// failure. Returns true only if both succeed.
bool llama_memory_hybrid_context::apply() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    const bool attn_applied = ctx_attn->apply();
    const bool recr_applied = ctx_recr->apply();

    return attn_applied && recr_applied;
}
|
|
230
|
+
|
|
231
|
+
// Status fixed at construction (combined from the sub-contexts, or given explicitly).
llama_memory_status llama_memory_hybrid_context::get_status() const {
    const llama_memory_status current = status;
    return current;
}
|
|
234
|
+
|
|
235
|
+
// The ubatch currently being processed; the index is advanced by next().
const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    const llama_ubatch & current = ubatches[i_next];
    return current;
}
|
|
239
|
+
|
|
240
|
+
// Attention sub-context, downcast to its concrete type.
// NOTE(review): the unchecked static_cast assumes ctx_attn always comes from the unified
// KV cache (as in every constructor of this class) - confirm init_full/init_update return that type.
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
    const auto * base = ctx_attn.get();
    return static_cast<const llama_kv_cache_unified_context *>(base);
}
|
|
243
|
+
|
|
244
|
+
// Recurrent sub-context, downcast to its concrete type.
// NOTE(review): the unchecked static_cast assumes ctx_recr always comes from the recurrent
// memory (as in every constructor of this class) - confirm.
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
    const auto * base = ctx_recr.get();
    return static_cast<const llama_memory_recurrent_context *>(base);
}
|