@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -87,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
|
87
87
|
|
|
88
88
|
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
|
89
89
|
if (pos_bucket) {
|
|
90
|
-
|
|
90
|
+
mctx->set_input_pos_bucket(pos_bucket, ubatch);
|
|
91
91
|
}
|
|
92
92
|
}
|
|
93
93
|
|
|
94
94
|
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
|
|
95
|
-
|
|
96
|
-
//GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
|
|
95
|
+
GGML_ASSERT(out_ids);
|
|
97
96
|
|
|
98
|
-
|
|
99
|
-
LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
|
|
100
|
-
} else {
|
|
101
|
-
const int64_t n_tokens = ubatch->n_tokens;
|
|
97
|
+
const int64_t n_tokens = ubatch->n_tokens;
|
|
102
98
|
|
|
103
|
-
|
|
104
|
-
|
|
99
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
|
|
100
|
+
int32_t * data = (int32_t *) out_ids->data;
|
|
105
101
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
data[0] = n_tokens - 1;
|
|
122
|
-
} else {
|
|
123
|
-
GGML_ASSERT(n_outputs == 0);
|
|
124
|
-
}
|
|
102
|
+
if (n_outputs == n_tokens) {
|
|
103
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
104
|
+
data[i] = i;
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
return;
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
GGML_ASSERT(ubatch->output);
|
|
111
|
+
|
|
112
|
+
int n_outputs = 0;
|
|
113
|
+
|
|
114
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
115
|
+
if (ubatch->output[i]) {
|
|
116
|
+
data[n_outputs++] = i;
|
|
125
117
|
}
|
|
126
118
|
}
|
|
127
119
|
}
|
|
@@ -130,110 +122,97 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
|
130
122
|
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
|
131
123
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
132
124
|
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
|
133
|
-
const int64_t
|
|
125
|
+
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
|
134
126
|
|
|
135
127
|
GGML_ASSERT(mean);
|
|
136
128
|
GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
|
|
137
129
|
|
|
138
130
|
float * data = (float *) mean->data;
|
|
139
|
-
memset(mean->data, 0, n_tokens
|
|
140
|
-
|
|
141
|
-
std::vector<uint64_t> sum(n_tokens, 0);
|
|
131
|
+
memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
|
|
142
132
|
|
|
143
|
-
|
|
144
|
-
for (int
|
|
145
|
-
|
|
133
|
+
std::vector<uint64_t> sums(n_seqs_unq, 0);
|
|
134
|
+
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
|
135
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
136
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
137
|
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
|
146
138
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
sum[seq_id] += ubatch->n_seq_tokens;
|
|
139
|
+
sums[seq_idx] += ubatch->n_seq_tokens;
|
|
140
|
+
}
|
|
151
141
|
}
|
|
152
142
|
|
|
153
|
-
std::vector<float> div(
|
|
154
|
-
for (int
|
|
155
|
-
const uint64_t
|
|
156
|
-
if (
|
|
157
|
-
div[
|
|
143
|
+
std::vector<float> div(n_seqs_unq, 0.0f);
|
|
144
|
+
for (int s = 0; s < n_seqs_unq; ++s) {
|
|
145
|
+
const uint64_t sum = sums[s];
|
|
146
|
+
if (sum > 0) {
|
|
147
|
+
div[s] = 1.0f/float(sum);
|
|
158
148
|
}
|
|
159
149
|
}
|
|
160
150
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
151
|
+
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
|
152
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
153
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
154
|
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
|
164
155
|
|
|
165
|
-
|
|
166
|
-
|
|
156
|
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
|
157
|
+
data[seq_idx*n_tokens + i + j] = div[seq_idx];
|
|
158
|
+
}
|
|
167
159
|
}
|
|
168
160
|
}
|
|
169
161
|
}
|
|
170
162
|
}
|
|
171
163
|
|
|
172
164
|
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
const int64_t n_tokens = ubatch->n_tokens;
|
|
177
|
-
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
|
178
|
-
const int64_t n_seqs = ubatch->n_seqs;
|
|
165
|
+
const int64_t n_tokens = ubatch->n_tokens;
|
|
166
|
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
|
167
|
+
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
|
179
168
|
|
|
169
|
+
if (cparams.embeddings && (
|
|
170
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
|
171
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
|
|
172
|
+
)) {
|
|
180
173
|
GGML_ASSERT(cls);
|
|
181
174
|
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
|
182
175
|
|
|
183
176
|
uint32_t * data = (uint32_t *) cls->data;
|
|
184
|
-
memset(cls->data, 0,
|
|
185
|
-
|
|
186
|
-
// TODO: fix indexing [UBATCH_IDX]
|
|
187
|
-
for (int s = 0; s < n_seqs; ++s) {
|
|
188
|
-
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
|
177
|
+
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
|
189
178
|
|
|
190
|
-
|
|
191
|
-
|
|
179
|
+
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
|
180
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
181
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
182
|
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
|
192
183
|
|
|
193
|
-
|
|
194
|
-
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
|
|
195
|
-
|
|
196
|
-
if (pos == 0) {
|
|
197
|
-
data[seq_id] = s*n_seq_tokens + i;
|
|
198
|
-
}
|
|
184
|
+
data[seq_idx] = i;
|
|
199
185
|
}
|
|
200
186
|
}
|
|
201
187
|
}
|
|
202
188
|
|
|
203
189
|
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
|
|
204
|
-
const int64_t n_tokens = ubatch->n_tokens;
|
|
205
|
-
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
|
206
|
-
const int64_t n_seqs = ubatch->n_seqs;
|
|
207
|
-
|
|
208
190
|
GGML_ASSERT(cls);
|
|
209
191
|
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
|
210
192
|
|
|
211
193
|
uint32_t * data = (uint32_t *) cls->data;
|
|
212
|
-
memset(cls->data, 0,
|
|
213
|
-
|
|
214
|
-
std::vector<int> last_pos(n_tokens, -1);
|
|
215
|
-
std::vector<int> last_row(n_tokens, -1);
|
|
194
|
+
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
|
216
195
|
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
|
196
|
+
std::vector<int> last_pos(n_seqs_unq, -1);
|
|
197
|
+
std::vector<int> last_row(n_seqs_unq, -1);
|
|
220
198
|
|
|
221
|
-
|
|
222
|
-
|
|
199
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
200
|
+
const llama_pos pos = ubatch->pos[i];
|
|
223
201
|
|
|
224
|
-
for (int
|
|
225
|
-
const
|
|
202
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
203
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
204
|
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
|
226
205
|
|
|
227
|
-
if (pos >= last_pos[
|
|
228
|
-
last_pos[
|
|
229
|
-
last_row[
|
|
206
|
+
if (pos >= last_pos[seq_idx]) {
|
|
207
|
+
last_pos[seq_idx] = pos;
|
|
208
|
+
last_row[seq_idx] = i;
|
|
230
209
|
}
|
|
231
210
|
}
|
|
232
211
|
}
|
|
233
212
|
|
|
234
|
-
for (int
|
|
235
|
-
if (last_row[
|
|
236
|
-
data[
|
|
213
|
+
for (int s = 0; s < n_seqs_unq; ++s) {
|
|
214
|
+
if (last_row[s] >= 0) {
|
|
215
|
+
data[s] = last_row[s];
|
|
237
216
|
}
|
|
238
217
|
}
|
|
239
218
|
}
|
|
@@ -242,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
|
242
221
|
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
|
243
222
|
GGML_UNUSED(ubatch);
|
|
244
223
|
|
|
245
|
-
const int64_t n_rs =
|
|
224
|
+
const int64_t n_rs = mctx->get_n_rs();
|
|
246
225
|
|
|
247
226
|
if (s_copy) {
|
|
248
227
|
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
|
@@ -250,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
|
|
250
229
|
|
|
251
230
|
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
252
231
|
for (uint32_t i = 0; i < n_rs; ++i) {
|
|
253
|
-
data[i] =
|
|
232
|
+
data[i] = mctx->s_copy(i);
|
|
254
233
|
}
|
|
255
234
|
}
|
|
256
235
|
}
|
|
@@ -266,160 +245,99 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
|
|
266
245
|
}
|
|
267
246
|
|
|
268
247
|
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
|
281
|
-
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
|
282
|
-
|
|
283
|
-
for (int j = 0; j < n_seq_tokens; ++j) {
|
|
284
|
-
const int32_t tj = s1*n_seq_tokens + j;
|
|
285
|
-
|
|
286
|
-
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
|
287
|
-
for (int i = 0; i < n_seq_tokens; ++i) {
|
|
288
|
-
const int32_t ti = s0*n_seq_tokens + i;
|
|
289
|
-
float f = -INFINITY;
|
|
290
|
-
|
|
291
|
-
// TODO: fix indexing [UBATCH_IDX]
|
|
292
|
-
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
|
293
|
-
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
|
|
294
|
-
if (hparams.use_alibi) {
|
|
295
|
-
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
|
296
|
-
} else {
|
|
297
|
-
f = 0.0f;
|
|
298
|
-
}
|
|
299
|
-
break;
|
|
300
|
-
}
|
|
301
|
-
}
|
|
302
|
-
|
|
303
|
-
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
|
|
304
|
-
}
|
|
305
|
-
}
|
|
306
|
-
}
|
|
307
|
-
}
|
|
308
|
-
}
|
|
309
|
-
} else {
|
|
310
|
-
const int64_t n_tokens = ubatch->n_tokens;
|
|
311
|
-
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
|
312
|
-
const int64_t n_seqs = ubatch->n_seqs;
|
|
313
|
-
const int64_t n_stride = ubatch->n_tokens;
|
|
314
|
-
|
|
315
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
|
316
|
-
|
|
317
|
-
float * data = (float *) kq_mask->data;
|
|
318
|
-
|
|
319
|
-
for (int h = 0; h < 1; ++h) {
|
|
320
|
-
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
|
321
|
-
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
|
322
|
-
|
|
323
|
-
for (int j = 0; j < n_seq_tokens; ++j) {
|
|
324
|
-
const int32_t tj = s1*n_seq_tokens + j;
|
|
325
|
-
|
|
326
|
-
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
|
327
|
-
for (int i = 0; i < n_seq_tokens; ++i) {
|
|
328
|
-
const int32_t ti = s0*n_seq_tokens + i;
|
|
329
|
-
float f = -INFINITY;
|
|
330
|
-
|
|
331
|
-
// TODO: fix indexing [UBATCH_IDX]
|
|
332
|
-
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
|
333
|
-
if (ubatch->seq_id[s0][s] == seq_id) {
|
|
334
|
-
if (hparams.use_alibi) {
|
|
335
|
-
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
|
336
|
-
} else {
|
|
337
|
-
f = 0.0f;
|
|
338
|
-
}
|
|
339
|
-
break;
|
|
340
|
-
}
|
|
341
|
-
}
|
|
342
|
-
|
|
343
|
-
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
|
344
|
-
}
|
|
345
|
-
}
|
|
248
|
+
const int64_t n_kv = ubatch->n_tokens;
|
|
249
|
+
const int64_t n_tokens = ubatch->n_tokens;
|
|
250
|
+
|
|
251
|
+
GGML_ASSERT(kq_mask);
|
|
252
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
|
253
|
+
|
|
254
|
+
float * data = (float *) kq_mask->data;
|
|
255
|
+
|
|
256
|
+
for (int h = 0; h < 1; ++h) {
|
|
257
|
+
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
|
258
|
+
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
|
346
259
|
|
|
347
|
-
|
|
348
|
-
|
|
260
|
+
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
|
261
|
+
float f = -INFINITY;
|
|
262
|
+
|
|
263
|
+
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
|
264
|
+
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
|
265
|
+
|
|
266
|
+
// TODO: reimplement this like in llama_kv_cache_unified
|
|
267
|
+
if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
|
|
268
|
+
if (hparams.use_alibi) {
|
|
269
|
+
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
|
|
270
|
+
} else {
|
|
271
|
+
f = 0.0f;
|
|
349
272
|
}
|
|
273
|
+
break;
|
|
350
274
|
}
|
|
351
275
|
}
|
|
276
|
+
|
|
277
|
+
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
|
|
352
278
|
}
|
|
353
279
|
}
|
|
354
280
|
}
|
|
355
281
|
}
|
|
356
282
|
|
|
357
283
|
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
284
|
+
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
|
285
|
+
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
|
286
|
+
|
|
287
|
+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
|
361
288
|
}
|
|
362
289
|
|
|
363
290
|
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
}
|
|
291
|
+
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
|
292
|
+
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
|
367
293
|
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
294
|
+
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
|
295
|
+
|
|
296
|
+
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
|
|
297
|
+
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
|
298
|
+
|
|
299
|
+
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
|
371
300
|
}
|
|
372
301
|
|
|
373
302
|
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
374
|
-
|
|
375
|
-
const int64_t n_enc = cross_kq_mask->ne[0];
|
|
376
|
-
const int64_t n_tokens = ubatch->n_tokens;
|
|
303
|
+
GGML_ASSERT(cross_kq_mask);
|
|
377
304
|
|
|
378
|
-
|
|
379
|
-
|
|
305
|
+
const int64_t n_enc = cross_kq_mask->ne[0];
|
|
306
|
+
const int64_t n_tokens = ubatch->n_tokens;
|
|
380
307
|
|
|
381
|
-
|
|
308
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
|
309
|
+
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
|
382
310
|
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
311
|
+
float * data = (float *) cross_kq_mask->data;
|
|
312
|
+
|
|
313
|
+
for (int h = 0; h < 1; ++h) {
|
|
314
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
315
|
+
for (int j = 0; j < n_enc; ++j) {
|
|
316
|
+
float f = -INFINITY;
|
|
317
|
+
|
|
318
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
319
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
320
|
+
|
|
321
|
+
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
|
322
|
+
f = 0.0f;
|
|
393
323
|
}
|
|
394
|
-
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
|
|
395
324
|
}
|
|
325
|
+
|
|
326
|
+
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
|
|
396
327
|
}
|
|
328
|
+
}
|
|
397
329
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
}
|
|
330
|
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
|
331
|
+
for (int j = 0; j < n_enc; ++j) {
|
|
332
|
+
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
|
402
333
|
}
|
|
403
334
|
}
|
|
404
335
|
}
|
|
405
336
|
}
|
|
406
337
|
|
|
407
338
|
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
}
|
|
411
|
-
|
|
412
|
-
const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
|
|
413
|
-
|
|
414
|
-
if (s_copy) {
|
|
415
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
|
416
|
-
int32_t * data = (int32_t *) s_copy->data;
|
|
417
|
-
|
|
418
|
-
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
419
|
-
for (uint32_t i = 0; i < n_rs; ++i) {
|
|
420
|
-
data[i] = mem_state->get_state_recr()->s_copy(i);
|
|
421
|
-
}
|
|
422
|
-
}
|
|
339
|
+
inp_attn->set_input(ubatch);
|
|
340
|
+
inp_rs->set_input(ubatch);
|
|
423
341
|
}
|
|
424
342
|
|
|
425
343
|
//
|
|
@@ -461,16 +379,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
|
461
379
|
backend_cpu (params.backend_cpu),
|
|
462
380
|
cvec (params.cvec),
|
|
463
381
|
loras (params.loras),
|
|
464
|
-
|
|
382
|
+
mctx (params.mctx),
|
|
465
383
|
cross (params.cross),
|
|
466
384
|
cb_func (params.cb),
|
|
467
385
|
res (std::make_unique<llm_graph_result>()) {
|
|
468
386
|
}
|
|
469
387
|
|
|
470
|
-
int64_t llm_graph_context::n_pos_per_embd() const {
|
|
471
|
-
return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
|
|
472
|
-
}
|
|
473
|
-
|
|
474
388
|
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
|
|
475
389
|
if (cb_func) {
|
|
476
390
|
cb_func(ubatch, cur, name, il);
|
|
@@ -630,12 +544,20 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
630
544
|
|
|
631
545
|
switch (type_op) {
|
|
632
546
|
case LLM_FFN_SILU:
|
|
633
|
-
{
|
|
547
|
+
if (gate && type_gate == LLM_FFN_PAR) {
|
|
548
|
+
cur = ggml_swiglu_split(ctx0, cur, tmp);
|
|
549
|
+
cb(cur, "ffn_swiglu", il);
|
|
550
|
+
type_gate = LLM_FFN_SEQ;
|
|
551
|
+
} else {
|
|
634
552
|
cur = ggml_silu(ctx0, cur);
|
|
635
553
|
cb(cur, "ffn_silu", il);
|
|
636
554
|
} break;
|
|
637
555
|
case LLM_FFN_GELU:
|
|
638
|
-
{
|
|
556
|
+
if (gate && type_gate == LLM_FFN_PAR) {
|
|
557
|
+
cur = ggml_geglu_split(ctx0, cur, tmp);
|
|
558
|
+
cb(cur, "ffn_geglu", il);
|
|
559
|
+
type_gate = LLM_FFN_SEQ;
|
|
560
|
+
} else {
|
|
639
561
|
cur = ggml_gelu(ctx0, cur);
|
|
640
562
|
cb(cur, "ffn_gelu", il);
|
|
641
563
|
if (act_scales != NULL) {
|
|
@@ -644,7 +566,11 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
644
566
|
}
|
|
645
567
|
} break;
|
|
646
568
|
case LLM_FFN_RELU:
|
|
647
|
-
{
|
|
569
|
+
if (gate && type_gate == LLM_FFN_PAR) {
|
|
570
|
+
cur = ggml_reglu_split(ctx0, cur, tmp);
|
|
571
|
+
cb(cur, "ffn_reglu", il);
|
|
572
|
+
type_gate = LLM_FFN_SEQ;
|
|
573
|
+
} else {
|
|
648
574
|
cur = ggml_relu(ctx0, cur);
|
|
649
575
|
cb(cur, "ffn_relu", il);
|
|
650
576
|
} break;
|
|
@@ -658,32 +584,19 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
658
584
|
} break;
|
|
659
585
|
case LLM_FFN_SWIGLU:
|
|
660
586
|
{
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
|
664
|
-
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
|
665
|
-
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
|
666
|
-
|
|
667
|
-
x0 = ggml_silu(ctx0, x0);
|
|
668
|
-
cb(cur, "ffn_silu", il);
|
|
669
|
-
|
|
670
|
-
cur = ggml_mul(ctx0, x0, x1);
|
|
671
|
-
cb(cur, "ffn_mul", il);
|
|
587
|
+
cur = ggml_swiglu(ctx0, cur);
|
|
588
|
+
cb(cur, "ffn_swiglu", il);
|
|
672
589
|
} break;
|
|
673
590
|
case LLM_FFN_GEGLU:
|
|
674
591
|
{
|
|
675
|
-
|
|
676
|
-
int64_t split_point = cur->ne[0] / 2;
|
|
677
|
-
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
|
678
|
-
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
|
679
|
-
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
|
680
|
-
|
|
681
|
-
x0 = ggml_gelu(ctx0, x0);
|
|
682
|
-
cb(x0, "ffn_gelu", il);
|
|
683
|
-
|
|
684
|
-
cur = ggml_mul(ctx0, x0, x1);
|
|
592
|
+
cur = ggml_geglu(ctx0, cur);
|
|
685
593
|
cb(cur, "ffn_geglu", il);
|
|
686
594
|
} break;
|
|
595
|
+
case LLM_FFN_REGLU:
|
|
596
|
+
{
|
|
597
|
+
cur = ggml_reglu(ctx0, cur);
|
|
598
|
+
cb(cur, "ffn_reglu", il);
|
|
599
|
+
} break;
|
|
687
600
|
}
|
|
688
601
|
|
|
689
602
|
if (gate && type_gate == LLM_FFN_PAR) {
|
|
@@ -813,12 +726,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
813
726
|
|
|
814
727
|
switch (type_op) {
|
|
815
728
|
case LLM_FFN_SILU:
|
|
816
|
-
{
|
|
729
|
+
if (gate_exps) {
|
|
730
|
+
cur = ggml_swiglu_split(ctx0, cur, up);
|
|
731
|
+
cb(cur, "ffn_moe_swiglu", il);
|
|
732
|
+
} else {
|
|
817
733
|
cur = ggml_silu(ctx0, cur);
|
|
818
734
|
cb(cur, "ffn_moe_silu", il);
|
|
819
735
|
} break;
|
|
820
736
|
case LLM_FFN_GELU:
|
|
821
|
-
{
|
|
737
|
+
if (gate_exps) {
|
|
738
|
+
cur = ggml_geglu_split(ctx0, cur, up);
|
|
739
|
+
cb(cur, "ffn_moe_geglu", il);
|
|
740
|
+
} else {
|
|
822
741
|
cur = ggml_gelu(ctx0, cur);
|
|
823
742
|
cb(cur, "ffn_moe_gelu", il);
|
|
824
743
|
} break;
|
|
@@ -826,11 +745,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
826
745
|
GGML_ABORT("fatal error");
|
|
827
746
|
}
|
|
828
747
|
|
|
829
|
-
if (gate_exps) {
|
|
830
|
-
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
|
|
831
|
-
cb(cur, "ffn_moe_gate_par", il);
|
|
832
|
-
}
|
|
833
|
-
|
|
834
748
|
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
|
835
749
|
cb(experts, "ffn_moe_down", il);
|
|
836
750
|
|
|
@@ -915,11 +829,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|
|
915
829
|
}
|
|
916
830
|
|
|
917
831
|
ggml_tensor * llm_graph_context::build_inp_pos() const {
|
|
918
|
-
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
|
|
832
|
+
auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
|
|
919
833
|
|
|
920
834
|
auto & cur = inp->pos;
|
|
921
835
|
|
|
922
|
-
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
|
|
836
|
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
|
|
923
837
|
ggml_set_input(cur);
|
|
924
838
|
|
|
925
839
|
res->add_input(std::move(inp));
|
|
@@ -942,6 +856,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
|
|
942
856
|
}
|
|
943
857
|
|
|
944
858
|
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
|
|
859
|
+
// note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
|
|
860
|
+
// but this would make the graph topology depend on the number of output tokens, which can interfere with
|
|
861
|
+
// features that require constant topology such as pipeline parallelism
|
|
862
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
|
|
863
|
+
//if (n_outputs < n_tokens) {
|
|
864
|
+
// return nullptr;
|
|
865
|
+
//}
|
|
866
|
+
|
|
945
867
|
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
|
|
946
868
|
|
|
947
869
|
auto & cur = inp->out_ids;
|
|
@@ -959,7 +881,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
|
|
|
959
881
|
|
|
960
882
|
auto & cur = inp->mean;
|
|
961
883
|
|
|
962
|
-
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens,
|
|
884
|
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
|
|
963
885
|
ggml_set_input(cur);
|
|
964
886
|
|
|
965
887
|
res->add_input(std::move(inp));
|
|
@@ -972,7 +894,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
|
|
|
972
894
|
|
|
973
895
|
auto & cur = inp->cls;
|
|
974
896
|
|
|
975
|
-
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32,
|
|
897
|
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
|
|
976
898
|
ggml_set_input(cur);
|
|
977
899
|
|
|
978
900
|
res->add_input(std::move(inp));
|
|
@@ -1018,11 +940,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
|
|
1018
940
|
}
|
|
1019
941
|
|
|
1020
942
|
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
|
1021
|
-
const auto *
|
|
943
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
|
1022
944
|
|
|
1023
|
-
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams,
|
|
945
|
+
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
|
|
1024
946
|
|
|
1025
|
-
const auto n_kv =
|
|
947
|
+
const auto n_kv = mctx_cur->get_n_kv();
|
|
1026
948
|
|
|
1027
949
|
auto & cur = inp->pos_bucket;
|
|
1028
950
|
|
|
@@ -1049,33 +971,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
|
|
|
1049
971
|
return pos_bias;
|
|
1050
972
|
}
|
|
1051
973
|
|
|
1052
|
-
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|
1053
|
-
const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
|
|
1054
|
-
|
|
1055
|
-
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
|
|
1056
|
-
|
|
1057
|
-
{
|
|
1058
|
-
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
|
1059
|
-
|
|
1060
|
-
const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
|
|
1061
|
-
|
|
1062
|
-
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
1063
|
-
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
1064
|
-
ggml_set_input(inp->self_kq_mask);
|
|
1065
|
-
|
|
1066
|
-
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
1067
|
-
}
|
|
1068
|
-
|
|
1069
|
-
{
|
|
1070
|
-
const auto n_rs = mem_state->get_state_recr()->get_n_rs();
|
|
1071
|
-
|
|
1072
|
-
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
|
1073
|
-
ggml_set_input(inp->s_copy);
|
|
1074
|
-
}
|
|
1075
|
-
|
|
1076
|
-
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
|
1077
|
-
}
|
|
1078
|
-
|
|
1079
974
|
ggml_tensor * llm_graph_context::build_attn_mha(
|
|
1080
975
|
ggml_cgraph * gf,
|
|
1081
976
|
ggml_tensor * q,
|
|
@@ -1197,8 +1092,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|
|
1197
1092
|
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
|
1198
1093
|
|
|
1199
1094
|
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
|
1200
|
-
inp->kq_mask =
|
|
1201
|
-
//cb(inp_kq_mask, "KQ_mask", -1);
|
|
1095
|
+
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
|
1202
1096
|
ggml_set_input(inp->kq_mask);
|
|
1203
1097
|
|
|
1204
1098
|
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
|
@@ -1250,23 +1144,38 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1250
1144
|
return cur;
|
|
1251
1145
|
}
|
|
1252
1146
|
|
|
1253
|
-
llm_graph_input_attn_kv_unified
|
|
1254
|
-
|
|
1147
|
+
static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
|
|
1148
|
+
ggml_context * ctx0,
|
|
1149
|
+
const llama_ubatch & ubatch,
|
|
1150
|
+
const llama_hparams & hparams,
|
|
1151
|
+
const llama_cparams & cparams,
|
|
1152
|
+
const llama_kv_cache_unified_context * mctx_cur) {
|
|
1255
1153
|
|
|
1256
|
-
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams,
|
|
1154
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
|
|
1257
1155
|
|
|
1258
1156
|
{
|
|
1259
1157
|
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
|
1260
1158
|
|
|
1261
|
-
const auto n_kv =
|
|
1159
|
+
const auto n_kv = mctx_cur->get_n_kv();
|
|
1160
|
+
const auto n_tokens = ubatch.n_tokens;
|
|
1262
1161
|
|
|
1263
|
-
inp->
|
|
1264
|
-
|
|
1162
|
+
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
|
1163
|
+
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
|
1164
|
+
|
|
1165
|
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
|
1265
1166
|
ggml_set_input(inp->self_kq_mask);
|
|
1266
1167
|
|
|
1267
1168
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
1268
1169
|
}
|
|
1269
1170
|
|
|
1171
|
+
return inp;
|
|
1172
|
+
}
|
|
1173
|
+
|
|
1174
|
+
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
|
1175
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
|
1176
|
+
|
|
1177
|
+
auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
|
1178
|
+
|
|
1270
1179
|
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
|
1271
1180
|
}
|
|
1272
1181
|
|
|
@@ -1288,19 +1197,22 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1288
1197
|
ggml_build_forward_expand(gf, k_cur);
|
|
1289
1198
|
ggml_build_forward_expand(gf, v_cur);
|
|
1290
1199
|
|
|
1291
|
-
const auto *
|
|
1200
|
+
const auto * mctx_cur = inp->mctx;
|
|
1292
1201
|
|
|
1293
1202
|
// store to KV cache
|
|
1294
1203
|
{
|
|
1295
|
-
|
|
1296
|
-
|
|
1204
|
+
const auto & k_idxs = inp->get_k_idxs();
|
|
1205
|
+
const auto & v_idxs = inp->get_v_idxs();
|
|
1206
|
+
|
|
1207
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
|
1208
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
|
1297
1209
|
}
|
|
1298
1210
|
|
|
1299
1211
|
const auto & kq_mask = inp->get_kq_mask();
|
|
1300
1212
|
|
|
1301
1213
|
ggml_tensor * q = q_cur;
|
|
1302
|
-
ggml_tensor * k =
|
|
1303
|
-
ggml_tensor * v =
|
|
1214
|
+
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
|
1215
|
+
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
|
1304
1216
|
|
|
1305
1217
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1306
1218
|
cb(cur, "kqv_out", il);
|
|
@@ -1335,26 +1247,39 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1335
1247
|
// these nodes are added to the graph together so that they are not reordered
|
|
1336
1248
|
// by doing so, the number of splits in the graph is reduced
|
|
1337
1249
|
ggml_build_forward_expand(gf, q_cur);
|
|
1338
|
-
ggml_build_forward_expand(gf, k_cur);
|
|
1339
|
-
ggml_build_forward_expand(gf, v_cur);
|
|
1340
1250
|
|
|
1341
|
-
|
|
1251
|
+
if (k_cur) {
|
|
1252
|
+
ggml_build_forward_expand(gf, k_cur);
|
|
1253
|
+
}
|
|
1254
|
+
|
|
1255
|
+
if (v_cur) {
|
|
1256
|
+
ggml_build_forward_expand(gf, v_cur);
|
|
1257
|
+
}
|
|
1258
|
+
|
|
1259
|
+
const auto * mctx_iswa = inp->mctx;
|
|
1342
1260
|
|
|
1343
1261
|
const bool is_swa = hparams.is_swa(il);
|
|
1344
1262
|
|
|
1345
|
-
const auto *
|
|
1263
|
+
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
|
|
1346
1264
|
|
|
1347
|
-
// store to KV cache
|
|
1348
|
-
{
|
|
1349
|
-
|
|
1350
|
-
|
|
1265
|
+
// optionally store to KV cache
|
|
1266
|
+
if (k_cur) {
|
|
1267
|
+
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
|
|
1268
|
+
|
|
1269
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
|
1270
|
+
}
|
|
1271
|
+
|
|
1272
|
+
if (v_cur) {
|
|
1273
|
+
const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
|
|
1274
|
+
|
|
1275
|
+
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
|
1351
1276
|
}
|
|
1352
1277
|
|
|
1353
1278
|
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
|
1354
1279
|
|
|
1355
1280
|
ggml_tensor * q = q_cur;
|
|
1356
|
-
ggml_tensor * k =
|
|
1357
|
-
ggml_tensor * v =
|
|
1281
|
+
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
|
1282
|
+
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
|
1358
1283
|
|
|
1359
1284
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1360
1285
|
cb(cur, "kqv_out", il);
|
|
@@ -1379,7 +1304,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|
|
1379
1304
|
|
|
1380
1305
|
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
|
1381
1306
|
|
|
1382
|
-
inp->cross_kq_mask =
|
|
1307
|
+
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
|
1383
1308
|
ggml_set_input(inp->cross_kq_mask);
|
|
1384
1309
|
|
|
1385
1310
|
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
|
@@ -1429,66 +1354,21 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1429
1354
|
return cur;
|
|
1430
1355
|
}
|
|
1431
1356
|
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
|
|
1435
|
-
ggml_tensor * wo,
|
|
1436
|
-
ggml_tensor * wo_b,
|
|
1437
|
-
ggml_tensor * q_cur,
|
|
1438
|
-
ggml_tensor * k_cur,
|
|
1439
|
-
ggml_tensor * v_cur,
|
|
1440
|
-
ggml_tensor * kq_b,
|
|
1441
|
-
ggml_tensor * v_mla,
|
|
1442
|
-
float kq_scale,
|
|
1443
|
-
int il) const {
|
|
1444
|
-
// these nodes are added to the graph together so that they are not reordered
|
|
1445
|
-
// by doing so, the number of splits in the graph is reduced
|
|
1446
|
-
ggml_build_forward_expand(gf, q_cur);
|
|
1447
|
-
ggml_build_forward_expand(gf, k_cur);
|
|
1448
|
-
ggml_build_forward_expand(gf, v_cur);
|
|
1449
|
-
|
|
1450
|
-
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
|
|
1451
|
-
|
|
1452
|
-
// store to KV cache
|
|
1453
|
-
{
|
|
1454
|
-
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
|
1455
|
-
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
|
1456
|
-
}
|
|
1457
|
-
|
|
1458
|
-
const auto & kq_mask = inp->get_kq_mask();
|
|
1459
|
-
|
|
1460
|
-
ggml_tensor * q = q_cur;
|
|
1461
|
-
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
|
1462
|
-
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
|
1463
|
-
|
|
1464
|
-
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1465
|
-
cb(cur, "kqv_out", il);
|
|
1466
|
-
|
|
1467
|
-
if (wo) {
|
|
1468
|
-
cur = build_lora_mm(wo, cur);
|
|
1469
|
-
if (arch == LLM_ARCH_GLM4) {
|
|
1470
|
-
// GLM4 seems to have numerical issues with half-precision accumulators
|
|
1471
|
-
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
|
1472
|
-
}
|
|
1473
|
-
}
|
|
1474
|
-
|
|
1475
|
-
if (wo_b) {
|
|
1476
|
-
cur = ggml_add(ctx0, cur, wo_b);
|
|
1477
|
-
}
|
|
1478
|
-
|
|
1479
|
-
return cur;
|
|
1480
|
-
}
|
|
1481
|
-
|
|
1357
|
+
// TODO: maybe separate the inner implementation into a separate function
|
|
1358
|
+
// like with the non-sliding window equivalent
|
|
1359
|
+
// once sliding-window hybrid caches are a thing.
|
|
1482
1360
|
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
|
1483
|
-
const auto *
|
|
1361
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
|
1484
1362
|
|
|
1485
|
-
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams,
|
|
1363
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
|
1486
1364
|
|
|
1487
1365
|
{
|
|
1488
|
-
const auto n_kv =
|
|
1366
|
+
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
|
1489
1367
|
|
|
1490
|
-
inp->
|
|
1491
|
-
|
|
1368
|
+
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
|
1369
|
+
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
|
1370
|
+
|
|
1371
|
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
|
1492
1372
|
ggml_set_input(inp->self_kq_mask);
|
|
1493
1373
|
|
|
1494
1374
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
@@ -1497,10 +1377,12 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
|
1497
1377
|
{
|
|
1498
1378
|
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
|
1499
1379
|
|
|
1500
|
-
const auto n_kv =
|
|
1380
|
+
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
|
1381
|
+
|
|
1382
|
+
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
|
1383
|
+
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
|
1501
1384
|
|
|
1502
|
-
inp->self_kq_mask_swa =
|
|
1503
|
-
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
|
1385
|
+
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
|
1504
1386
|
ggml_set_input(inp->self_kq_mask_swa);
|
|
1505
1387
|
|
|
1506
1388
|
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
|
@@ -1519,7 +1401,7 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
|
1519
1401
|
uint32_t kv_head,
|
|
1520
1402
|
uint32_t kv_size,
|
|
1521
1403
|
int32_t rs_zero,
|
|
1522
|
-
|
|
1404
|
+
const llm_graph_get_rows_fn & get_state_rows) const {
|
|
1523
1405
|
|
|
1524
1406
|
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
|
|
1525
1407
|
|
|
@@ -1528,19 +1410,11 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
|
1528
1410
|
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
|
|
1529
1411
|
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
|
|
1530
1412
|
|
|
1531
|
-
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
// {state_size, kv_size} -> {state_size, n_seqs}
|
|
1537
|
-
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
|
|
1538
|
-
ggml_build_forward_expand(gf, output_states);
|
|
1539
|
-
} else {
|
|
1540
|
-
// FIXME: make the gathering operation happen before the copy below
|
|
1541
|
-
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
|
|
1542
|
-
output_states = states;
|
|
1543
|
-
}
|
|
1413
|
+
// copy states
|
|
1414
|
+
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
|
1415
|
+
// {state_size, kv_size} -> {state_size, n_seqs}
|
|
1416
|
+
ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
|
|
1417
|
+
ggml_build_forward_expand(gf, output_states);
|
|
1544
1418
|
|
|
1545
1419
|
// copy extra states which won't be changed further (between n_seqs and n_kv)
|
|
1546
1420
|
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
|
|
@@ -1552,41 +1426,38 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
|
1552
1426
|
return output_states;
|
|
1553
1427
|
}
|
|
1554
1428
|
|
|
1555
|
-
llm_graph_input_rs
|
|
1556
|
-
|
|
1429
|
+
static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
|
|
1430
|
+
ggml_context * ctx0,
|
|
1431
|
+
const llama_memory_recurrent_context * mctx_cur) {
|
|
1557
1432
|
|
|
1558
|
-
auto inp = std::make_unique<llm_graph_input_rs>(
|
|
1433
|
+
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
|
|
1559
1434
|
|
|
1560
|
-
const auto n_rs =
|
|
1435
|
+
const auto n_rs = mctx_cur->get_n_rs();
|
|
1561
1436
|
|
|
1562
1437
|
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
|
1563
1438
|
ggml_set_input(inp->s_copy);
|
|
1564
1439
|
|
|
1565
|
-
return
|
|
1440
|
+
return inp;
|
|
1566
1441
|
}
|
|
1567
1442
|
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1571
|
-
|
|
1572
|
-
int32_t state_size,
|
|
1573
|
-
int32_t n_seqs,
|
|
1574
|
-
bool avoid_copies) const {
|
|
1575
|
-
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
|
1443
|
+
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
|
1444
|
+
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
|
1445
|
+
|
|
1446
|
+
auto inp = build_rs_inp_impl(ctx0, mctx_cur);
|
|
1576
1447
|
|
|
1577
|
-
return
|
|
1448
|
+
return (llm_graph_input_rs *) res->add_input(std::move(inp));
|
|
1578
1449
|
}
|
|
1579
1450
|
|
|
1580
1451
|
ggml_tensor * llm_graph_context::build_rs(
|
|
1581
|
-
|
|
1452
|
+
llm_graph_input_rs * inp,
|
|
1582
1453
|
ggml_cgraph * gf,
|
|
1583
1454
|
ggml_tensor * s,
|
|
1584
1455
|
int32_t state_size,
|
|
1585
1456
|
int32_t n_seqs,
|
|
1586
|
-
|
|
1587
|
-
const auto * kv_state =
|
|
1457
|
+
const llm_graph_get_rows_fn & get_state_rows) const {
|
|
1458
|
+
const auto * kv_state = inp->mctx;
|
|
1588
1459
|
|
|
1589
|
-
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
|
|
1460
|
+
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
|
|
1590
1461
|
}
|
|
1591
1462
|
|
|
1592
1463
|
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
@@ -1594,13 +1465,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
|
1594
1465
|
ggml_cgraph * gf,
|
|
1595
1466
|
const llama_ubatch & ubatch,
|
|
1596
1467
|
int il) const {
|
|
1597
|
-
const auto *
|
|
1468
|
+
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
|
1598
1469
|
|
|
1599
1470
|
const auto token_shift_count = hparams.token_shift_count;
|
|
1600
1471
|
|
|
1601
1472
|
const int64_t n_seqs = ubatch.n_seqs;
|
|
1602
1473
|
|
|
1603
|
-
ggml_tensor * token_shift_all =
|
|
1474
|
+
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
|
|
1604
1475
|
|
|
1605
1476
|
ggml_tensor * token_shift = build_rs(
|
|
1606
1477
|
inp, gf, token_shift_all,
|
|
@@ -1615,22 +1486,33 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
|
1615
1486
|
ggml_tensor * token_shift,
|
|
1616
1487
|
const llama_ubatch & ubatch,
|
|
1617
1488
|
int il) const {
|
|
1618
|
-
const auto *
|
|
1489
|
+
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
|
1619
1490
|
|
|
1620
1491
|
const auto token_shift_count = hparams.token_shift_count;
|
|
1621
1492
|
const auto n_embd = hparams.n_embd;
|
|
1622
1493
|
|
|
1623
1494
|
const int64_t n_seqs = ubatch.n_seqs;
|
|
1624
1495
|
|
|
1625
|
-
const auto kv_head =
|
|
1496
|
+
const auto kv_head = mctx_cur->get_head();
|
|
1626
1497
|
|
|
1627
1498
|
return ggml_cpy(
|
|
1628
1499
|
ctx0,
|
|
1629
1500
|
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
|
1630
|
-
ggml_view_1d(ctx0,
|
|
1501
|
+
ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
|
|
1631
1502
|
);
|
|
1632
1503
|
}
|
|
1633
1504
|
|
|
1505
|
+
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|
1506
|
+
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
|
1507
|
+
|
|
1508
|
+
auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
|
|
1509
|
+
auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
|
1510
|
+
|
|
1511
|
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
1512
|
+
|
|
1513
|
+
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
|
1514
|
+
}
|
|
1515
|
+
|
|
1634
1516
|
void llm_graph_context::build_pooling(
|
|
1635
1517
|
ggml_cgraph * gf,
|
|
1636
1518
|
ggml_tensor * cls,
|