@novastera-oss/llamarn 0.2.6 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +141 -38
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +58 -24
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +37 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +53 -40
- package/cpp/llama.cpp/common/common.h +6 -2
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
- package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
- package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +141 -38
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
- package/cpp/llama.cpp/src/llama-arch.h +25 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
- package/cpp/llama.cpp/src/llama-batch.h +110 -57
- package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +360 -266
- package/cpp/llama.cpp/src/llama-context.h +27 -23
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
- package/cpp/llama.cpp/src/llama-graph.h +126 -58
- package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
- package/cpp/llama.cpp/src/llama-hparams.h +16 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
- package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +73 -36
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
- package/cpp/llama.cpp/src/llama-model.h +26 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
- package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +6 -2
- package/ios/include/llama.h +141 -38
- package/ios/libs/llama.xcframework/Info.plist +15 -15
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
#include "llama-
|
|
1
|
+
#include "llama-memory-recurrent.h"
|
|
2
2
|
|
|
3
3
|
#include "llama-impl.h"
|
|
4
|
+
#include "llama-io.h"
|
|
4
5
|
#include "llama-batch.h"
|
|
5
6
|
#include "llama-model.h"
|
|
6
7
|
|
|
@@ -11,27 +12,28 @@
|
|
|
11
12
|
#include <stdexcept>
|
|
12
13
|
|
|
13
14
|
//
|
|
14
|
-
//
|
|
15
|
+
// llama_memory_recurrent
|
|
15
16
|
//
|
|
16
17
|
|
|
17
|
-
|
|
18
|
-
const llama_model &
|
|
19
|
-
|
|
20
|
-
ggml_type
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
uint32_t
|
|
18
|
+
llama_memory_recurrent::llama_memory_recurrent(
|
|
19
|
+
const llama_model & model,
|
|
20
|
+
layer_filter_cb && filter,
|
|
21
|
+
ggml_type type_r,
|
|
22
|
+
ggml_type type_s,
|
|
23
|
+
bool offload,
|
|
24
|
+
uint32_t mem_size,
|
|
25
|
+
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
|
24
26
|
const int32_t n_layer = hparams.n_layer;
|
|
25
27
|
|
|
26
|
-
LLAMA_LOG_INFO("%s:
|
|
27
|
-
__func__,
|
|
28
|
+
LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
|
|
29
|
+
__func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
|
|
28
30
|
|
|
29
31
|
head = 0;
|
|
30
|
-
size =
|
|
32
|
+
size = mem_size;
|
|
31
33
|
used = 0;
|
|
32
34
|
|
|
33
35
|
cells.clear();
|
|
34
|
-
cells.resize(
|
|
36
|
+
cells.resize(mem_size);
|
|
35
37
|
|
|
36
38
|
// create a context for each buffer type
|
|
37
39
|
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
@@ -58,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
|
58
60
|
return it->second;
|
|
59
61
|
};
|
|
60
62
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
+
r_l.resize(n_layer);
|
|
64
|
+
s_l.resize(n_layer);
|
|
63
65
|
|
|
64
66
|
for (int i = 0; i < n_layer; i++) {
|
|
65
|
-
|
|
66
|
-
|
|
67
|
+
if (filter && !filter(i)) {
|
|
68
|
+
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
|
|
69
|
+
continue;
|
|
70
|
+
}
|
|
67
71
|
|
|
68
72
|
const char * dev_name = "CPU";
|
|
69
73
|
|
|
@@ -83,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
|
83
87
|
throw std::runtime_error("failed to create ggml context for kv cache");
|
|
84
88
|
}
|
|
85
89
|
|
|
86
|
-
ggml_tensor *
|
|
87
|
-
ggml_tensor *
|
|
88
|
-
ggml_format_name(
|
|
89
|
-
ggml_format_name(
|
|
90
|
-
|
|
91
|
-
|
|
90
|
+
ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
|
|
91
|
+
ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
|
|
92
|
+
ggml_format_name(r, "cache_r_l%d", i);
|
|
93
|
+
ggml_format_name(s, "cache_s_l%d", i);
|
|
94
|
+
r_l[i] = r;
|
|
95
|
+
s_l[i] = s;
|
|
92
96
|
}
|
|
93
97
|
|
|
94
98
|
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
@@ -106,32 +110,35 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
|
106
110
|
}
|
|
107
111
|
|
|
108
112
|
{
|
|
109
|
-
const size_t
|
|
110
|
-
const size_t
|
|
113
|
+
const size_t memory_size_r = size_r_bytes();
|
|
114
|
+
const size_t memory_size_s = size_s_bytes();
|
|
111
115
|
|
|
112
|
-
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB,
|
|
113
|
-
(float)(
|
|
114
|
-
ggml_type_name(
|
|
115
|
-
ggml_type_name(
|
|
116
|
+
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
|
|
117
|
+
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
|
|
118
|
+
ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
|
|
119
|
+
ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
|
|
116
120
|
}
|
|
117
121
|
}
|
|
118
122
|
|
|
119
|
-
void
|
|
123
|
+
void llama_memory_recurrent::clear(bool data) {
|
|
120
124
|
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
|
121
125
|
cells[i].pos = -1;
|
|
122
126
|
cells[i].seq_id.clear();
|
|
123
127
|
cells[i].src = -1;
|
|
124
128
|
cells[i].tail = -1;
|
|
125
129
|
}
|
|
130
|
+
|
|
126
131
|
head = 0;
|
|
127
132
|
used = 0;
|
|
128
133
|
|
|
129
|
-
|
|
130
|
-
|
|
134
|
+
if (data) {
|
|
135
|
+
for (auto & buf : bufs) {
|
|
136
|
+
ggml_backend_buffer_clear(buf.get(), 0);
|
|
137
|
+
}
|
|
131
138
|
}
|
|
132
139
|
}
|
|
133
140
|
|
|
134
|
-
bool
|
|
141
|
+
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
|
135
142
|
uint32_t new_head = size;
|
|
136
143
|
|
|
137
144
|
if (p0 < 0) {
|
|
@@ -150,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
|
|
150
157
|
if (0 <= seq_id) {
|
|
151
158
|
int32_t & tail_id = cells[seq_id].tail;
|
|
152
159
|
if (tail_id >= 0) {
|
|
153
|
-
const
|
|
160
|
+
const auto & cell = cells[tail_id];
|
|
154
161
|
// partial intersection is invalid
|
|
155
162
|
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
|
156
163
|
return false;
|
|
@@ -198,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
|
|
198
205
|
return true;
|
|
199
206
|
}
|
|
200
207
|
|
|
201
|
-
void
|
|
208
|
+
void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
|
202
209
|
if (seq_id_src == seq_id_dst) {
|
|
203
210
|
return;
|
|
204
211
|
}
|
|
@@ -212,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|
|
212
219
|
}
|
|
213
220
|
|
|
214
221
|
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
|
215
|
-
|
|
216
|
-
|
|
222
|
+
auto & tail_src = cells[seq_id_src];
|
|
223
|
+
auto & tail_dst = cells[seq_id_dst];
|
|
217
224
|
if (tail_dst.tail >= 0) {
|
|
218
225
|
// clear destination seq_id if it wasn't empty
|
|
219
|
-
|
|
226
|
+
auto & cell_dst = cells[tail_dst.tail];
|
|
220
227
|
|
|
221
228
|
cell_dst.seq_id.erase(seq_id_dst);
|
|
222
229
|
tail_dst.tail = -1;
|
|
@@ -227,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|
|
227
234
|
}
|
|
228
235
|
}
|
|
229
236
|
if (tail_src.tail >= 0) {
|
|
230
|
-
|
|
237
|
+
auto & cell_src = cells[tail_src.tail];
|
|
231
238
|
|
|
232
239
|
cell_src.seq_id.insert(seq_id_dst);
|
|
233
240
|
tail_dst.tail = tail_src.tail;
|
|
@@ -235,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|
|
235
242
|
}
|
|
236
243
|
}
|
|
237
244
|
|
|
238
|
-
void
|
|
245
|
+
void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
|
|
239
246
|
uint32_t new_head = size;
|
|
240
247
|
|
|
241
248
|
for (uint32_t i = 0; i < size; ++i) {
|
|
@@ -267,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
|
|
267
274
|
}
|
|
268
275
|
}
|
|
269
276
|
|
|
270
|
-
void
|
|
277
|
+
void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
|
271
278
|
if (shift == 0) {
|
|
272
279
|
return;
|
|
273
280
|
}
|
|
@@ -289,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
|
|
289
296
|
if (0 <= seq_id && seq_id < (int64_t) size) {
|
|
290
297
|
const int32_t tail_id = cells[seq_id].tail;
|
|
291
298
|
if (tail_id >= 0) {
|
|
292
|
-
|
|
299
|
+
auto & cell = cells[tail_id];
|
|
293
300
|
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
|
294
301
|
cell.pos += shift;
|
|
295
302
|
}
|
|
@@ -297,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
|
|
297
304
|
}
|
|
298
305
|
}
|
|
299
306
|
|
|
300
|
-
void
|
|
307
|
+
void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
301
308
|
if (d == 1) {
|
|
302
309
|
return;
|
|
303
310
|
}
|
|
@@ -319,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
|
|
319
326
|
if (0 <= seq_id && seq_id < (int64_t) size) {
|
|
320
327
|
const int32_t tail_id = cells[seq_id].tail;
|
|
321
328
|
if (tail_id >= 0) {
|
|
322
|
-
|
|
329
|
+
auto & cell = cells[tail_id];
|
|
323
330
|
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
|
324
331
|
cell.pos /= d;
|
|
325
332
|
}
|
|
@@ -327,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
|
|
327
334
|
}
|
|
328
335
|
}
|
|
329
336
|
|
|
330
|
-
llama_pos
|
|
337
|
+
llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
|
331
338
|
llama_pos result = std::numeric_limits<llama_pos>::max();
|
|
332
339
|
|
|
333
340
|
for (uint32_t i = 0; i < size; ++i) {
|
|
@@ -343,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
|
|
343
350
|
return result;
|
|
344
351
|
}
|
|
345
352
|
|
|
346
|
-
llama_pos
|
|
353
|
+
llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|
347
354
|
llama_pos result = -1;
|
|
348
355
|
|
|
349
356
|
for (uint32_t i = 0; i < size; ++i) {
|
|
@@ -355,38 +362,50 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
355
362
|
return result;
|
|
356
363
|
}
|
|
357
364
|
|
|
358
|
-
|
|
359
|
-
|
|
365
|
+
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
|
366
|
+
do {
|
|
367
|
+
balloc.split_reset();
|
|
360
368
|
|
|
361
|
-
|
|
369
|
+
std::vector<llama_ubatch> ubatches;
|
|
370
|
+
while (true) {
|
|
371
|
+
llama_ubatch ubatch;
|
|
362
372
|
|
|
363
|
-
|
|
373
|
+
if (embd_all) {
|
|
374
|
+
// if all tokens are output, split by sequence
|
|
375
|
+
ubatch = balloc.split_seq(n_ubatch);
|
|
376
|
+
} else {
|
|
377
|
+
ubatch = balloc.split_equal(n_ubatch);
|
|
378
|
+
}
|
|
364
379
|
|
|
365
|
-
|
|
366
|
-
|
|
380
|
+
if (ubatch.n_tokens == 0) {
|
|
381
|
+
break;
|
|
382
|
+
}
|
|
367
383
|
|
|
368
|
-
|
|
369
|
-
// Pooled embeddings cannot be split across ubatches (yet)
|
|
370
|
-
ubatch = sbatch.split_seq(n_ubatch);
|
|
371
|
-
} else {
|
|
372
|
-
ubatch = sbatch.split_equal(n_ubatch);
|
|
384
|
+
ubatches.push_back(std::move(ubatch)); // NOLINT
|
|
373
385
|
}
|
|
374
386
|
|
|
375
|
-
ubatches
|
|
376
|
-
|
|
387
|
+
if (!prepare(ubatches)) {
|
|
388
|
+
break;
|
|
389
|
+
}
|
|
377
390
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
391
|
+
return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
|
|
392
|
+
} while (false);
|
|
393
|
+
|
|
394
|
+
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
395
|
+
}
|
|
381
396
|
|
|
382
|
-
|
|
397
|
+
llama_memory_context_ptr llama_memory_recurrent::init_full() {
|
|
398
|
+
return std::make_unique<llama_memory_recurrent_context>(this);
|
|
383
399
|
}
|
|
384
400
|
|
|
385
|
-
|
|
386
|
-
|
|
401
|
+
llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
|
|
402
|
+
GGML_UNUSED(lctx);
|
|
403
|
+
GGML_UNUSED(optimize);
|
|
404
|
+
|
|
405
|
+
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
|
387
406
|
}
|
|
388
407
|
|
|
389
|
-
bool
|
|
408
|
+
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
|
390
409
|
// simply remember the full state because it is very small for this type of cache
|
|
391
410
|
// TODO: optimize
|
|
392
411
|
auto org_cells = cells;
|
|
@@ -395,21 +414,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
|
|
395
414
|
|
|
396
415
|
bool success = true;
|
|
397
416
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
// recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
|
|
405
|
-
//
|
|
406
|
-
GGML_UNUSED(ubatches);
|
|
407
|
-
//for (const auto & ubatch : ubatches) {
|
|
408
|
-
// if (!find_slot(ubatch)) {
|
|
409
|
-
// success = false;
|
|
410
|
-
// break;
|
|
411
|
-
// }
|
|
412
|
-
//}
|
|
417
|
+
for (const auto & ubatch : ubatches) {
|
|
418
|
+
if (!find_slot(ubatch)) {
|
|
419
|
+
success = false;
|
|
420
|
+
break;
|
|
421
|
+
}
|
|
422
|
+
}
|
|
413
423
|
|
|
414
424
|
// restore the original state
|
|
415
425
|
cells = std::move(org_cells);
|
|
@@ -419,26 +429,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
|
|
419
429
|
return success;
|
|
420
430
|
}
|
|
421
431
|
|
|
422
|
-
bool
|
|
423
|
-
GGML_UNUSED(lctx);
|
|
424
|
-
// noop
|
|
425
|
-
return false;
|
|
426
|
-
}
|
|
427
|
-
|
|
428
|
-
void llama_kv_cache_recurrent::defrag_sched(float thold) {
|
|
429
|
-
GGML_UNUSED(thold);
|
|
430
|
-
// noop
|
|
431
|
-
}
|
|
432
|
-
|
|
433
|
-
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
434
|
-
const uint32_t n_tokens = ubatch.n_tokens;
|
|
435
|
-
const uint32_t n_seqs = ubatch.n_seqs;
|
|
436
|
-
|
|
432
|
+
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
437
433
|
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
434
|
+
const uint32_t n_seqs = ubatch.n_seqs;
|
|
438
435
|
|
|
439
436
|
// if we have enough unused cells before the current head ->
|
|
440
437
|
// better to start searching from the beginning of the cache, hoping to fill it
|
|
441
|
-
if (head > used + 2*
|
|
438
|
+
if (head > used + 2*n_seqs) {
|
|
442
439
|
head = 0;
|
|
443
440
|
}
|
|
444
441
|
|
|
@@ -454,9 +451,11 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
454
451
|
|
|
455
452
|
// everything should fit if all seq_ids are smaller than the max
|
|
456
453
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
457
|
-
const uint32_t
|
|
454
|
+
const uint32_t i = s*n_seq_tokens; // first token of sequence set s
|
|
455
|
+
const uint32_t n_seq_id = ubatch.n_seq_id[i];
|
|
456
|
+
|
|
458
457
|
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
|
459
|
-
const llama_seq_id seq_id = ubatch.seq_id[
|
|
458
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][j];
|
|
460
459
|
|
|
461
460
|
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
|
462
461
|
// too big seq_id
|
|
@@ -465,9 +464,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
465
464
|
return false;
|
|
466
465
|
}
|
|
467
466
|
if (j > 0) {
|
|
468
|
-
|
|
467
|
+
auto & seq = cells[seq_id];
|
|
469
468
|
if (seq.tail >= 0) {
|
|
470
|
-
|
|
469
|
+
auto & cell = cells[seq.tail];
|
|
471
470
|
// clear cells from seq_ids that become shared
|
|
472
471
|
// (should not normally happen, but let's handle it anyway)
|
|
473
472
|
cell.seq_id.erase(seq_id);
|
|
@@ -487,7 +486,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
487
486
|
std::vector<int32_t> tails_verif;
|
|
488
487
|
tails_verif.assign(size, -1);
|
|
489
488
|
for (uint32_t i = 0; i < size; ++i) {
|
|
490
|
-
|
|
489
|
+
auto & cell = cells[i];
|
|
491
490
|
for (llama_seq_id seq_id : cell.seq_id) {
|
|
492
491
|
if (tails_verif[seq_id] != -1) {
|
|
493
492
|
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
|
@@ -508,42 +507,43 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
508
507
|
|
|
509
508
|
for (uint32_t i = 0; i < size; ++i) {
|
|
510
509
|
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
|
511
|
-
|
|
510
|
+
auto & cell = cells[next_empty_cell];
|
|
512
511
|
if (cell.is_empty()) { break; }
|
|
513
512
|
next_empty_cell += 1;
|
|
514
513
|
}
|
|
515
514
|
|
|
516
515
|
// find usable cell range
|
|
517
516
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
518
|
-
const
|
|
519
|
-
|
|
517
|
+
const uint32_t i = s*n_seq_tokens;
|
|
518
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
519
|
+
auto & seq_meta = cells[seq_id];
|
|
520
520
|
bool has_cell = false;
|
|
521
521
|
if (seq_meta.tail >= 0) {
|
|
522
|
-
|
|
522
|
+
auto & cell = cells[seq_meta.tail];
|
|
523
523
|
GGML_ASSERT(cell.has_seq_id(seq_id));
|
|
524
524
|
// does this seq_id "own" the cell?
|
|
525
525
|
if (cell.seq_id.size() == 1) { has_cell = true; }
|
|
526
526
|
}
|
|
527
527
|
if (!has_cell) {
|
|
528
|
-
|
|
528
|
+
auto & empty_cell = cells[next_empty_cell];
|
|
529
529
|
GGML_ASSERT(empty_cell.is_empty());
|
|
530
530
|
// copy old tail into the empty cell
|
|
531
531
|
if (seq_meta.tail >= 0) {
|
|
532
|
-
|
|
532
|
+
auto & orig_cell = cells[seq_meta.tail];
|
|
533
533
|
empty_cell.pos = orig_cell.pos;
|
|
534
534
|
empty_cell.src = orig_cell.src;
|
|
535
535
|
orig_cell.seq_id.erase(seq_id);
|
|
536
536
|
empty_cell.seq_id.insert(seq_id); // will be overwritten
|
|
537
|
+
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
|
|
537
538
|
}
|
|
538
539
|
seq_meta.tail = next_empty_cell;
|
|
539
540
|
// find next empty cell
|
|
540
541
|
if (s + 1 < n_seqs) {
|
|
541
|
-
|
|
542
|
-
|
|
542
|
+
for (uint32_t j = 0; j < size; ++j) {
|
|
543
|
+
next_empty_cell += 1;
|
|
543
544
|
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
|
544
|
-
|
|
545
|
+
auto & cell = cells[next_empty_cell];
|
|
545
546
|
if (cell.is_empty()) { break; }
|
|
546
|
-
next_empty_cell += 1;
|
|
547
547
|
}
|
|
548
548
|
}
|
|
549
549
|
}
|
|
@@ -553,102 +553,99 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
553
553
|
|
|
554
554
|
// gather and re-order
|
|
555
555
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
556
|
-
|
|
557
|
-
int32_t
|
|
556
|
+
const uint32_t i = s*n_seq_tokens;
|
|
557
|
+
const int32_t dst_id = s + min;
|
|
558
|
+
const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
|
|
558
559
|
if (dst_id != src_id) {
|
|
559
|
-
|
|
560
|
-
|
|
560
|
+
auto & dst_cell = cells[dst_id];
|
|
561
|
+
auto & src_cell = cells[src_id];
|
|
561
562
|
|
|
562
563
|
std::swap(dst_cell.pos, src_cell.pos);
|
|
563
564
|
std::swap(dst_cell.src, src_cell.src);
|
|
564
565
|
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
|
565
566
|
|
|
566
|
-
// swap tails
|
|
567
|
-
for (
|
|
568
|
-
cells[
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
567
|
+
// swap tails
|
|
568
|
+
for (uint32_t j = 0; j < size; ++j) {
|
|
569
|
+
int32_t & tail = cells[j].tail;
|
|
570
|
+
if (tail == src_id) {
|
|
571
|
+
tail = dst_id;
|
|
572
|
+
} else if (tail == dst_id) {
|
|
573
|
+
tail = src_id;
|
|
574
|
+
}
|
|
572
575
|
}
|
|
573
576
|
}
|
|
574
577
|
}
|
|
575
578
|
|
|
576
579
|
// update the pos of the used seqs
|
|
577
580
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
578
|
-
const
|
|
579
|
-
|
|
580
|
-
|
|
581
|
+
const uint32_t i = s*n_seq_tokens;
|
|
582
|
+
const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
|
|
583
|
+
const int32_t cell_id = s + min;
|
|
584
|
+
auto & cell = cells[cell_id];
|
|
581
585
|
|
|
582
586
|
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
|
583
587
|
// What should happen when the pos backtracks or skips a value?
|
|
584
588
|
// Clearing the state mid-batch would require special-casing which isn't done.
|
|
585
589
|
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
|
|
586
|
-
__func__, last_pos, cell.pos, ubatch.seq_id[
|
|
590
|
+
__func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
|
|
587
591
|
}
|
|
588
592
|
cell.pos = last_pos;
|
|
589
593
|
cell.seq_id.clear();
|
|
590
|
-
for (int32_t j = 0; j < ubatch.n_seq_id[
|
|
591
|
-
const llama_seq_id seq_id = ubatch.seq_id[
|
|
594
|
+
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
|
|
595
|
+
const llama_seq_id seq_id = ubatch.seq_id[i][j];
|
|
592
596
|
cell.seq_id.insert(seq_id);
|
|
593
597
|
cells[seq_id].tail = cell_id;
|
|
594
598
|
}
|
|
595
599
|
}
|
|
596
600
|
|
|
601
|
+
// Find first cell without src refs, to use as the zero-ed state
|
|
602
|
+
{
|
|
603
|
+
// TODO: bake-in src refcounts in the cell metadata
|
|
604
|
+
std::vector<int32_t> refcounts(size, 0);
|
|
605
|
+
for (size_t i = 0; i < size; ++i) {
|
|
606
|
+
const int32_t src = cells[i].src;
|
|
607
|
+
if (src >= 0) {
|
|
608
|
+
refcounts[src] += 1;
|
|
609
|
+
}
|
|
610
|
+
}
|
|
611
|
+
|
|
612
|
+
rs_z = -1;
|
|
613
|
+
for (int i = min; i <= max; ++i) {
|
|
614
|
+
if (refcounts[i] == 0) {
|
|
615
|
+
rs_z = i;
|
|
616
|
+
break;
|
|
617
|
+
}
|
|
618
|
+
}
|
|
619
|
+
|
|
620
|
+
for (int i = min; i <= max; ++i) {
|
|
621
|
+
if (cells[i].src < 0) {
|
|
622
|
+
GGML_ASSERT(rs_z >= 0);
|
|
623
|
+
cells[i].src0 = rs_z;
|
|
624
|
+
} else {
|
|
625
|
+
// Stage the source ids for all used cells to allow correct seq_* behavior
|
|
626
|
+
// and still make these values available when setting the inputs
|
|
627
|
+
cells[i].src0 = cells[i].src;
|
|
628
|
+
}
|
|
629
|
+
cells[i].src = i; // avoid moving or clearing twice
|
|
630
|
+
}
|
|
631
|
+
}
|
|
632
|
+
|
|
597
633
|
// allow getting the range of used cells, from head to head + n
|
|
598
634
|
head = min;
|
|
599
635
|
n = max - min + 1;
|
|
600
636
|
used = std::count_if(cells.begin(), cells.end(),
|
|
601
|
-
[](const
|
|
637
|
+
[](const mem_cell & cell){ return !cell.is_empty(); });
|
|
602
638
|
|
|
603
639
|
// sanity check
|
|
604
640
|
return n >= n_seqs;
|
|
605
641
|
}
|
|
606
642
|
|
|
607
|
-
bool
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
int32_t llama_kv_cache_recurrent::s_copy(int i) const {
|
|
612
|
-
const uint32_t cell_id = i + head;
|
|
613
|
-
|
|
614
|
-
//////////////////////////////////////////////
|
|
615
|
-
// TODO: this should not mutate the KV cache !
|
|
616
|
-
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
|
617
|
-
|
|
618
|
-
// prevent out-of-bound sources
|
|
619
|
-
if (cell.src < 0 || (uint32_t) cell.src >= size) {
|
|
620
|
-
cell.src = cell_id;
|
|
621
|
-
}
|
|
622
|
-
|
|
623
|
-
int32_t res = cell.src;
|
|
624
|
-
|
|
625
|
-
// TODO: do not mutate the KV cache
|
|
626
|
-
// ensure copy only happens once
|
|
627
|
-
if (cell.src != (int32_t) cell_id) {
|
|
628
|
-
cell.src = cell_id;
|
|
629
|
-
}
|
|
630
|
-
|
|
631
|
-
return res;
|
|
632
|
-
}
|
|
633
|
-
|
|
634
|
-
float llama_kv_cache_recurrent::s_mask(int i) const {
|
|
635
|
-
const uint32_t cell_id = i + head;
|
|
636
|
-
|
|
637
|
-
//////////////////////////////////////////////
|
|
638
|
-
// TODO: this should not mutate the KV cache !
|
|
639
|
-
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
|
640
|
-
|
|
641
|
-
float res = (float) (cell.src >= 0);
|
|
642
|
-
|
|
643
|
-
// only clear once
|
|
644
|
-
if (cell.src < 0) {
|
|
645
|
-
cell.src = cell_id;
|
|
646
|
-
}
|
|
647
|
-
|
|
648
|
-
return res;
|
|
643
|
+
bool llama_memory_recurrent::get_can_shift() const {
|
|
644
|
+
// shifting the pos is trivial for recurrent models
|
|
645
|
+
return true;
|
|
649
646
|
}
|
|
650
647
|
|
|
651
|
-
size_t
|
|
648
|
+
size_t llama_memory_recurrent::total_size() const {
|
|
652
649
|
size_t size = 0;
|
|
653
650
|
for (const auto & buf : bufs) {
|
|
654
651
|
size += ggml_backend_buffer_get_size(buf.get());
|
|
@@ -657,27 +654,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
|
|
|
657
654
|
return size;
|
|
658
655
|
}
|
|
659
656
|
|
|
660
|
-
size_t
|
|
661
|
-
size_t
|
|
657
|
+
size_t llama_memory_recurrent::size_r_bytes() const {
|
|
658
|
+
size_t size_r_bytes = 0;
|
|
662
659
|
|
|
663
|
-
for (const auto &
|
|
664
|
-
|
|
660
|
+
for (const auto & r : r_l) {
|
|
661
|
+
if (r != nullptr) {
|
|
662
|
+
size_r_bytes += ggml_nbytes(r);
|
|
663
|
+
}
|
|
665
664
|
}
|
|
666
665
|
|
|
667
|
-
return
|
|
666
|
+
return size_r_bytes;
|
|
668
667
|
}
|
|
669
668
|
|
|
670
|
-
size_t
|
|
671
|
-
size_t
|
|
669
|
+
size_t llama_memory_recurrent::size_s_bytes() const {
|
|
670
|
+
size_t size_s_bytes = 0;
|
|
672
671
|
|
|
673
|
-
for (const auto &
|
|
674
|
-
|
|
672
|
+
for (const auto & s : s_l) {
|
|
673
|
+
if (s != nullptr) {
|
|
674
|
+
size_s_bytes += ggml_nbytes(s);
|
|
675
|
+
}
|
|
675
676
|
}
|
|
676
677
|
|
|
677
|
-
return
|
|
678
|
+
return size_s_bytes;
|
|
678
679
|
}
|
|
679
680
|
|
|
680
|
-
void
|
|
681
|
+
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
|
681
682
|
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
|
682
683
|
uint32_t cell_count = 0;
|
|
683
684
|
|
|
@@ -715,7 +716,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
|
|
|
715
716
|
state_write_data(io, cell_ranges);
|
|
716
717
|
}
|
|
717
718
|
|
|
718
|
-
void
|
|
719
|
+
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
|
719
720
|
uint32_t cell_count;
|
|
720
721
|
io.read_to(&cell_count, sizeof(cell_count));
|
|
721
722
|
|
|
@@ -726,7 +727,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
|
|
|
726
727
|
|
|
727
728
|
if (!res) {
|
|
728
729
|
if (seq_id == -1) {
|
|
729
|
-
clear();
|
|
730
|
+
clear(true);
|
|
730
731
|
} else {
|
|
731
732
|
seq_rm(seq_id, -1, -1);
|
|
732
733
|
}
|
|
@@ -734,7 +735,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
|
|
|
734
735
|
}
|
|
735
736
|
}
|
|
736
737
|
|
|
737
|
-
void
|
|
738
|
+
void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
|
738
739
|
for (const auto & range : cell_ranges) {
|
|
739
740
|
for (uint32_t i = range.first; i < range.second; ++i) {
|
|
740
741
|
const auto & cell = cells[i];
|
|
@@ -753,98 +754,93 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
|
|
|
753
754
|
}
|
|
754
755
|
}
|
|
755
756
|
|
|
756
|
-
void
|
|
757
|
-
const uint32_t
|
|
757
|
+
void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
|
758
|
+
const uint32_t s_trans = 0;
|
|
758
759
|
const uint32_t n_layer = hparams.n_layer;
|
|
759
760
|
|
|
760
|
-
io.write(&
|
|
761
|
-
io.write(&n_layer,
|
|
761
|
+
io.write(&s_trans, sizeof(s_trans));
|
|
762
|
+
io.write(&n_layer, sizeof(n_layer));
|
|
762
763
|
|
|
763
764
|
std::vector<uint8_t> tmp_buf;
|
|
764
765
|
|
|
765
766
|
// Iterate and write all the keys first, each row is a cell
|
|
766
767
|
// Get whole range at a time
|
|
767
768
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
768
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
|
769
769
|
|
|
770
770
|
// Write key type
|
|
771
|
-
const int32_t
|
|
772
|
-
io.write(&
|
|
771
|
+
const int32_t r_type_i = (int32_t)r_l[il]->type;
|
|
772
|
+
io.write(&r_type_i, sizeof(r_type_i));
|
|
773
773
|
|
|
774
774
|
// Write row size of key
|
|
775
|
-
const uint64_t
|
|
776
|
-
io.write(&
|
|
775
|
+
const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
|
776
|
+
io.write(&r_size_row, sizeof(r_size_row));
|
|
777
777
|
|
|
778
778
|
// Read each range of cells of k_size length each into tmp_buf and write out
|
|
779
779
|
for (const auto & range : cell_ranges) {
|
|
780
780
|
const size_t range_size = range.second - range.first;
|
|
781
|
-
const size_t buf_size = range_size *
|
|
782
|
-
io.write_tensor(
|
|
781
|
+
const size_t buf_size = range_size * r_size_row;
|
|
782
|
+
io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
|
|
783
783
|
}
|
|
784
784
|
}
|
|
785
785
|
|
|
786
|
-
if (!
|
|
786
|
+
if (!s_trans) {
|
|
787
787
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
788
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
|
789
788
|
|
|
790
789
|
// Write value type
|
|
791
|
-
const int32_t
|
|
792
|
-
io.write(&
|
|
790
|
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
|
791
|
+
io.write(&s_type_i, sizeof(s_type_i));
|
|
793
792
|
|
|
794
793
|
// Write row size of value
|
|
795
|
-
const uint64_t
|
|
796
|
-
io.write(&
|
|
794
|
+
const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
|
795
|
+
io.write(&s_size_row, sizeof(s_size_row));
|
|
797
796
|
|
|
798
|
-
// Read each range of cells of
|
|
797
|
+
// Read each range of cells of s_size length each into tmp_buf and write out
|
|
799
798
|
for (const auto & range : cell_ranges) {
|
|
800
799
|
const size_t range_size = range.second - range.first;
|
|
801
|
-
const size_t buf_size = range_size *
|
|
802
|
-
io.write_tensor(
|
|
800
|
+
const size_t buf_size = range_size * s_size_row;
|
|
801
|
+
io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
|
|
803
802
|
}
|
|
804
803
|
}
|
|
805
804
|
} else {
|
|
806
805
|
// When v is transposed, we also need the element size and get the element ranges from each row
|
|
807
|
-
const uint32_t
|
|
806
|
+
const uint32_t mem_size = size;
|
|
808
807
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
809
|
-
const uint32_t
|
|
808
|
+
const uint32_t n_embd_s = hparams.n_embd_s();
|
|
810
809
|
|
|
811
810
|
// Write value type
|
|
812
|
-
const int32_t
|
|
813
|
-
io.write(&
|
|
811
|
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
|
812
|
+
io.write(&s_type_i, sizeof(s_type_i));
|
|
814
813
|
|
|
815
814
|
// Write element size
|
|
816
|
-
const uint32_t
|
|
817
|
-
io.write(&
|
|
815
|
+
const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
|
|
816
|
+
io.write(&s_size_el, sizeof(s_size_el));
|
|
818
817
|
|
|
819
818
|
// Write GQA embedding size
|
|
820
|
-
io.write(&
|
|
819
|
+
io.write(&n_embd_s, sizeof(n_embd_s));
|
|
821
820
|
|
|
822
821
|
// For each row, we get the element values of each cell
|
|
823
|
-
for (uint32_t j = 0; j <
|
|
822
|
+
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
|
824
823
|
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
|
825
824
|
for (const auto & range : cell_ranges) {
|
|
826
825
|
const size_t range_size = range.second - range.first;
|
|
827
|
-
const size_t src_offset = (range.first + j *
|
|
828
|
-
const size_t buf_size = range_size *
|
|
829
|
-
io.write_tensor(
|
|
826
|
+
const size_t src_offset = (range.first + j * mem_size) * s_size_el;
|
|
827
|
+
const size_t buf_size = range_size * s_size_el;
|
|
828
|
+
io.write_tensor(s_l[il], src_offset, buf_size);
|
|
830
829
|
}
|
|
831
830
|
}
|
|
832
831
|
}
|
|
833
832
|
}
|
|
834
833
|
}
|
|
835
834
|
|
|
836
|
-
bool
|
|
835
|
+
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
|
837
836
|
if (dest_seq_id != -1) {
|
|
838
837
|
// single sequence
|
|
839
838
|
|
|
840
839
|
seq_rm(dest_seq_id, -1, -1);
|
|
841
840
|
|
|
842
|
-
|
|
843
|
-
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
|
841
|
+
llama_batch_allocr balloc(hparams.n_pos_per_embd());
|
|
844
842
|
|
|
845
|
-
|
|
846
|
-
batch.n_seq_tokens = cell_count;
|
|
847
|
-
batch.n_seqs = 1;
|
|
843
|
+
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
|
|
848
844
|
|
|
849
845
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
850
846
|
llama_pos pos;
|
|
@@ -858,12 +854,12 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
|
858
854
|
return false;
|
|
859
855
|
}
|
|
860
856
|
|
|
861
|
-
|
|
857
|
+
ubatch.pos[i] = pos;
|
|
862
858
|
}
|
|
863
|
-
|
|
864
|
-
|
|
859
|
+
ubatch.n_seq_id[0] = 1;
|
|
860
|
+
ubatch.seq_id[0] = &dest_seq_id;
|
|
865
861
|
|
|
866
|
-
if (!find_slot(
|
|
862
|
+
if (!find_slot(ubatch)) {
|
|
867
863
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
|
868
864
|
return false;
|
|
869
865
|
}
|
|
@@ -871,8 +867,8 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
|
871
867
|
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
|
872
868
|
// Assume that this is one contiguous block of cells
|
|
873
869
|
GGML_ASSERT(head + cell_count <= size);
|
|
874
|
-
GGML_ASSERT(cells[head].pos ==
|
|
875
|
-
GGML_ASSERT(cells[head + cell_count - 1].pos ==
|
|
870
|
+
GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
|
|
871
|
+
GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
|
|
876
872
|
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
|
877
873
|
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
|
878
874
|
} else {
|
|
@@ -883,10 +879,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
|
883
879
|
return false;
|
|
884
880
|
}
|
|
885
881
|
|
|
886
|
-
clear();
|
|
882
|
+
clear(true);
|
|
887
883
|
|
|
888
884
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
889
|
-
|
|
885
|
+
auto & cell = cells[i];
|
|
890
886
|
|
|
891
887
|
llama_pos pos;
|
|
892
888
|
uint32_t n_seq_id;
|
|
@@ -900,7 +896,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
|
900
896
|
llama_seq_id seq_id;
|
|
901
897
|
io.read_to(&seq_id, sizeof(seq_id));
|
|
902
898
|
|
|
903
|
-
// TODO:
|
|
899
|
+
// TODO: llama_memory_recurrent should have a notion of max sequences
|
|
904
900
|
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
|
905
901
|
if (seq_id < 0) {
|
|
906
902
|
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
|
@@ -932,10 +928,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
|
932
928
|
return true;
|
|
933
929
|
}
|
|
934
930
|
|
|
935
|
-
bool
|
|
936
|
-
uint32_t
|
|
931
|
+
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
|
932
|
+
uint32_t s_trans;
|
|
937
933
|
uint32_t n_layer;
|
|
938
|
-
io.read_to(&
|
|
934
|
+
io.read_to(&s_trans, sizeof(s_trans));
|
|
939
935
|
io.read_to(&n_layer, sizeof(n_layer));
|
|
940
936
|
|
|
941
937
|
if (n_layer != hparams.n_layer) {
|
|
@@ -946,102 +942,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
|
|
946
942
|
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
|
947
943
|
return false;
|
|
948
944
|
}
|
|
949
|
-
if (false != (bool)
|
|
950
|
-
LLAMA_LOG_ERROR("%s: incompatible
|
|
945
|
+
if (false != (bool) s_trans) {
|
|
946
|
+
LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
|
|
951
947
|
return false;
|
|
952
948
|
}
|
|
953
949
|
|
|
954
950
|
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
|
955
951
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
956
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
|
957
952
|
|
|
958
953
|
// Read type of key
|
|
959
|
-
int32_t
|
|
960
|
-
io.read_to(&
|
|
961
|
-
const int32_t
|
|
962
|
-
if (
|
|
963
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
954
|
+
int32_t r_type_i_ref;
|
|
955
|
+
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
|
956
|
+
const int32_t r_type_i = (int32_t) r_l[il]->type;
|
|
957
|
+
if (r_type_i != r_type_i_ref) {
|
|
958
|
+
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
|
964
959
|
return false;
|
|
965
960
|
}
|
|
966
961
|
|
|
967
962
|
// Read row size of key
|
|
968
|
-
uint64_t
|
|
969
|
-
io.read_to(&
|
|
970
|
-
const size_t
|
|
971
|
-
if (
|
|
972
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
963
|
+
uint64_t r_size_row_ref;
|
|
964
|
+
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
|
965
|
+
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
|
966
|
+
if (r_size_row != r_size_row_ref) {
|
|
967
|
+
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
|
973
968
|
return false;
|
|
974
969
|
}
|
|
975
970
|
|
|
976
971
|
if (cell_count) {
|
|
977
972
|
// Read and set the keys for the whole cell range
|
|
978
|
-
ggml_backend_tensor_set(
|
|
973
|
+
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
|
|
979
974
|
}
|
|
980
975
|
}
|
|
981
976
|
|
|
982
|
-
if (!
|
|
977
|
+
if (!s_trans) {
|
|
983
978
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
984
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
|
985
979
|
|
|
986
980
|
// Read type of value
|
|
987
|
-
int32_t
|
|
988
|
-
io.read_to(&
|
|
989
|
-
const int32_t
|
|
990
|
-
if (
|
|
991
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
981
|
+
int32_t s_type_i_ref;
|
|
982
|
+
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
|
983
|
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
|
984
|
+
if (s_type_i != s_type_i_ref) {
|
|
985
|
+
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
|
992
986
|
return false;
|
|
993
987
|
}
|
|
994
988
|
|
|
995
989
|
// Read row size of value
|
|
996
|
-
uint64_t
|
|
997
|
-
io.read_to(&
|
|
998
|
-
const size_t
|
|
999
|
-
if (
|
|
1000
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
990
|
+
uint64_t s_size_row_ref;
|
|
991
|
+
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
|
992
|
+
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
|
993
|
+
if (s_size_row != s_size_row_ref) {
|
|
994
|
+
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
|
1001
995
|
return false;
|
|
1002
996
|
}
|
|
1003
997
|
|
|
1004
998
|
if (cell_count) {
|
|
1005
999
|
// Read and set the values for the whole cell range
|
|
1006
|
-
ggml_backend_tensor_set(
|
|
1000
|
+
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
|
|
1007
1001
|
}
|
|
1008
1002
|
}
|
|
1009
1003
|
} else {
|
|
1010
1004
|
// For each layer, read the values for each cell (transposed)
|
|
1011
1005
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
1012
|
-
const uint32_t
|
|
1006
|
+
const uint32_t n_embd_s = hparams.n_embd_s();
|
|
1013
1007
|
|
|
1014
1008
|
// Read type of value
|
|
1015
|
-
int32_t
|
|
1016
|
-
io.read_to(&
|
|
1017
|
-
const int32_t
|
|
1018
|
-
if (
|
|
1019
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
1009
|
+
int32_t s_type_i_ref;
|
|
1010
|
+
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
|
1011
|
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
|
1012
|
+
if (s_type_i != s_type_i_ref) {
|
|
1013
|
+
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
|
1020
1014
|
return false;
|
|
1021
1015
|
}
|
|
1022
1016
|
|
|
1023
1017
|
// Read element size of value
|
|
1024
|
-
uint32_t
|
|
1025
|
-
io.read_to(&
|
|
1026
|
-
const size_t
|
|
1027
|
-
if (
|
|
1028
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
1018
|
+
uint32_t s_size_el_ref;
|
|
1019
|
+
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
|
|
1020
|
+
const size_t s_size_el = ggml_type_size(s_l[il]->type);
|
|
1021
|
+
if (s_size_el != s_size_el_ref) {
|
|
1022
|
+
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
|
|
1029
1023
|
return false;
|
|
1030
1024
|
}
|
|
1031
1025
|
|
|
1032
|
-
// Read
|
|
1033
|
-
uint32_t
|
|
1034
|
-
io.read_to(&
|
|
1035
|
-
if (
|
|
1036
|
-
LLAMA_LOG_ERROR("%s: mismatched
|
|
1026
|
+
// Read state embedding size
|
|
1027
|
+
uint32_t n_embd_s_ref;
|
|
1028
|
+
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
|
1029
|
+
if (n_embd_s != n_embd_s_ref) {
|
|
1030
|
+
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
|
|
1037
1031
|
return false;
|
|
1038
1032
|
}
|
|
1039
1033
|
|
|
1040
1034
|
if (cell_count) {
|
|
1041
1035
|
// For each row in the transposed matrix, read the values for the whole cell range
|
|
1042
|
-
for (uint32_t j = 0; j <
|
|
1043
|
-
const size_t dst_offset = (head + j * size) *
|
|
1044
|
-
ggml_backend_tensor_set(
|
|
1036
|
+
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
|
1037
|
+
const size_t dst_offset = (head + j * size) * s_size_el;
|
|
1038
|
+
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
|
|
1045
1039
|
}
|
|
1046
1040
|
}
|
|
1047
1041
|
}
|
|
@@ -1051,25 +1045,22 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
|
|
1051
1045
|
}
|
|
1052
1046
|
|
|
1053
1047
|
//
|
|
1054
|
-
//
|
|
1048
|
+
// llama_memory_recurrent_context
|
|
1055
1049
|
//
|
|
1056
1050
|
|
|
1057
|
-
|
|
1051
|
+
llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
|
|
1058
1052
|
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
|
|
1053
|
+
llama_memory_recurrent_context::llama_memory_recurrent_context(
|
|
1054
|
+
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
|
|
1062
1055
|
}
|
|
1063
1056
|
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
llama_sbatch sbatch,
|
|
1068
|
-
std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
|
|
1057
|
+
llama_memory_recurrent_context::llama_memory_recurrent_context(
|
|
1058
|
+
llama_memory_recurrent * mem,
|
|
1059
|
+
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
|
|
1069
1060
|
|
|
1070
|
-
|
|
1061
|
+
llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
|
|
1071
1062
|
|
|
1072
|
-
bool
|
|
1063
|
+
bool llama_memory_recurrent_context::next() {
|
|
1073
1064
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1074
1065
|
|
|
1075
1066
|
if (++i_next >= ubatches.size()) {
|
|
@@ -1079,54 +1070,48 @@ bool llama_kv_cache_recurrent_state::next() {
|
|
|
1079
1070
|
return true;
|
|
1080
1071
|
}
|
|
1081
1072
|
|
|
1082
|
-
bool
|
|
1073
|
+
bool llama_memory_recurrent_context::apply() {
|
|
1083
1074
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1084
1075
|
|
|
1085
|
-
|
|
1076
|
+
mem->find_slot(ubatches[i_next]);
|
|
1086
1077
|
|
|
1087
1078
|
return true;
|
|
1088
1079
|
}
|
|
1089
1080
|
|
|
1090
|
-
|
|
1091
|
-
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1092
|
-
|
|
1093
|
-
return sbatch.out_ids;
|
|
1094
|
-
}
|
|
1095
|
-
|
|
1096
|
-
llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
|
|
1081
|
+
llama_memory_status llama_memory_recurrent_context::get_status() const {
|
|
1097
1082
|
return status;
|
|
1098
1083
|
}
|
|
1099
1084
|
|
|
1100
|
-
const llama_ubatch &
|
|
1085
|
+
const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
|
|
1101
1086
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
1102
1087
|
|
|
1103
1088
|
return ubatches[i_next];
|
|
1104
1089
|
}
|
|
1105
1090
|
|
|
1106
|
-
uint32_t
|
|
1107
|
-
return is_full ?
|
|
1091
|
+
uint32_t llama_memory_recurrent_context::get_n_rs() const {
|
|
1092
|
+
return is_full ? mem->size : mem->n;
|
|
1108
1093
|
}
|
|
1109
1094
|
|
|
1110
|
-
uint32_t
|
|
1111
|
-
return is_full ? 0 :
|
|
1095
|
+
uint32_t llama_memory_recurrent_context::get_head() const {
|
|
1096
|
+
return is_full ? 0 : mem->head;
|
|
1112
1097
|
}
|
|
1113
1098
|
|
|
1114
|
-
|
|
1115
|
-
return
|
|
1099
|
+
int32_t llama_memory_recurrent_context::get_rs_z() const {
|
|
1100
|
+
return is_full ? 0 : mem->rs_z;
|
|
1116
1101
|
}
|
|
1117
1102
|
|
|
1118
|
-
|
|
1119
|
-
return
|
|
1103
|
+
uint32_t llama_memory_recurrent_context::get_size() const {
|
|
1104
|
+
return mem->size;
|
|
1120
1105
|
}
|
|
1121
1106
|
|
|
1122
|
-
ggml_tensor *
|
|
1123
|
-
return
|
|
1107
|
+
ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
|
|
1108
|
+
return mem->r_l[il];
|
|
1124
1109
|
}
|
|
1125
1110
|
|
|
1126
|
-
|
|
1127
|
-
return
|
|
1111
|
+
ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
|
|
1112
|
+
return mem->s_l[il];
|
|
1128
1113
|
}
|
|
1129
1114
|
|
|
1130
|
-
|
|
1131
|
-
return
|
|
1115
|
+
int32_t llama_memory_recurrent_context::s_copy(int i) const {
|
|
1116
|
+
return mem->cells[i + mem->head].src0;
|
|
1132
1117
|
}
|