cui-llama.rn 1.6.0 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +35 -7
- package/android/src/main/CMakeLists.txt +22 -11
- package/android/src/main/java/com/rnllama/LlamaContext.java +42 -6
- package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
- package/android/src/main/jni.cpp +173 -18
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/LICENSE +21 -0
- package/cpp/chat.cpp +129 -107
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +58 -78
- package/cpp/common.h +29 -21
- package/cpp/ggml-alloc.c +4 -1
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpp.h +1 -1
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
- package/cpp/ggml-cpu/common.h +72 -0
- package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -103
- package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +306 -6
- package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +114 -55
- package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +32 -16
- package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +353 -173
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
- package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
- package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
- package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -6
- package/{ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/vec.h +16 -0
- package/cpp/ggml-cpu.h +5 -0
- package/cpp/ggml-impl.h +16 -9
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +810 -176
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +227 -282
- package/cpp/ggml.h +82 -101
- package/cpp/gguf.cpp +33 -33
- package/cpp/json-schema-to-grammar.cpp +3 -0
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +49 -17
- package/cpp/llama-arch.h +9 -0
- package/cpp/llama-batch.cpp +8 -2
- package/cpp/llama-batch.h +2 -1
- package/cpp/llama-chat.cpp +39 -16
- package/cpp/llama-chat.h +4 -2
- package/cpp/llama-context.cpp +440 -611
- package/cpp/llama-context.h +44 -33
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +214 -291
- package/cpp/llama-graph.h +69 -21
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +39 -5
- package/cpp/llama-kv-cache.cpp +2067 -620
- package/cpp/llama-kv-cache.h +410 -108
- package/cpp/llama-memory.h +12 -1
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +1089 -359
- package/cpp/llama-model.h +19 -3
- package/cpp/llama-sampling.cpp +20 -7
- package/cpp/llama-vocab.cpp +54 -9
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +86 -142
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +602 -190
- package/cpp/rn-llama.h +34 -8
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +20 -10
- package/ios/RNLlama.h +6 -0
- package/ios/RNLlama.mm +82 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +131 -38
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +54 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +72 -4
- package/src/index.ts +212 -38
- package/cpp/binary-ops.h +0 -16
- package/cpp/ops.h +0 -128
- package/cpp/simd-mappings.h +0 -888
- package/cpp/unary-ops.h +0 -28
- package/cpp/vec.h +0 -802
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
- /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
- /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
- /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
- /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
- /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
- /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
- /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
- /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
package/cpp/llama-graph.cpp
CHANGED
@@ -9,33 +9,6 @@
 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -55,7 +28,21 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
 
-
+        if (ubatch->token && n_pos_per_embd == 4) {
+            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
+            // the 3 first dims are the same, and 4th dim is all 0
+            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
+            // copy the first dimension
+            for (int i = 0; i < n_tokens; ++i) {
+                pos_data[               i] = ubatch->pos[i];
+                pos_data[    n_tokens + i] = ubatch->pos[i];
+                pos_data[2 * n_tokens + i] = ubatch->pos[i];
+                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
+            }
+            lm_ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*lm_ggml_element_size(pos));
+        } else {
+            lm_ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*lm_ggml_element_size(pos));
+        }
     }
 }
 
@@ -71,7 +58,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
                 ) * f_attn_temp_scale + 1.0;
         }
 
-        lm_ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*
+        lm_ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*lm_ggml_element_size(attn_scale));
     }
 }
 
@@ -96,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-
-
-        LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(pos_bucket->buffer));
-        LM_GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -270,24 +242,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            // prevent out-of-bound sources
-            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
-                kv_cell.src = cell_id;
-            }
-
-            data[i] = kv_cell.src;
-
-            // TODO: do not mutate the KV cache
-            // ensure copy only happens once
-            if (kv_cell.src != (int32_t) cell_id) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_copy(i);
         }
     }
 }
@@ -303,18 +258,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
         // clear unused states
         for (int i = 0; i < n_kv; ++i) {
-
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            data[i] = (float) (kv_cell.src >= 0);
-
-            // only clear once
-            if (kv_cell.src < 0) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_mask(i);
         }
     }
 }
@@ -417,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask
-
-
-
-        const int64_t n_seqs = ubatch->n_seqs;
-
-        float * data = nullptr;
-        float * data_swa = nullptr;
-
-        if (self_kq_mask) {
-            LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(self_kq_mask->buffer));
-            data = (float *) self_kq_mask->data;
-        }
-
-        if (self_kq_mask_swa) {
-            LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-            data_swa = (float *) self_kq_mask_swa->data;
-        }
-
-        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-        //   Causal mask:
-        //      xxx-------
-        //      xxxx------
-        //      xxxxx-----
-        //   Non-causal mask:
-        //      xxxxx-----
-        //      xxxxx-----
-        //      xxxxx-----
-        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-        for (int h = 0; h < 1; ++h) {
-            for (int s = 0; s < n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-                for (int j = 0; j < n_seq_tokens; ++j) {
-                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        // mask the token if:
-                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
-                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                        ) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(kv_self->cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
-                        }
-
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-
-                        // may need to cut off old tokens for sliding window
-                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                        if (data_swa) {
-                            if (hparams.n_attn_chunk) {
-                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                    f = -INFINITY;
-                                }
-                            } else {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                            }
-                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-                    }
-                }
-            }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+}
 
-
-
-
-
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                }
-            }
-        }
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
 
-
-
-            for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_kv; ++j) {
-                    data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                }
-            }
-        }
-    }
+    if (self_kq_mask_swa) {
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -559,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer (hparams.n_layer),
     n_rot (hparams.n_rot),
     n_ctx (cparams.n_ctx),
-    n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
     n_head (hparams.n_head()),
     n_head_kv (hparams.n_head_kv()),
     n_embd_head_k (hparams.n_embd_head_k),
@@ -592,7 +454,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res (std::make_unique<llm_graph_result>()) {
 }
 
-int64_t llm_graph_context::
+int64_t llm_graph_context::n_pos_per_embd() const {
     return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
 }
 
@@ -796,13 +658,17 @@ lm_ggml_tensor * llm_graph_context::build_ffn(
         } break;
     }
 
-    if (type_gate == LLM_FFN_PAR) {
+    if (gate && type_gate == LLM_FFN_PAR) {
        cur = lm_ggml_mul(ctx0, cur, tmp);
        cb(cur, "ffn_gate_par", il);
    }
 
    if (down) {
        cur = build_lora_mm(down, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
+        }
    }
 
    if (down_b) {
@@ -910,28 +776,35 @@ lm_ggml_tensor * llm_graph_context::build_moe_ffn(
     lm_ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    lm_ggml_tensor *
-
+    lm_ggml_tensor * experts = nullptr;
+    if (gate_exps) {
+        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate", il);
+    } else {
+        cur = up;
+    }
 
     switch (type_op) {
         case LLM_FFN_SILU:
             {
-
-                cb(
+                cur = lm_ggml_silu(ctx0, cur);
+                cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
             {
-
-                cb(
+                cur = lm_ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_moe_gelu", il);
             } break;
         default:
             LM_GGML_ABORT("fatal error");
     }
 
-
-
+    if (gate_exps) {
+        cur = lm_ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }
 
-
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     if (!weight_before_ffn) {
@@ -974,6 +847,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_embd(lm_ggml_tensor * tok_embd) co
         inp->tokens = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, ubatch.n_tokens);
         //cb(inp->tokens, "inp_tokens", -1);
         lm_ggml_set_input(inp->tokens);
+        res->t_tokens = inp->tokens;
 
         cur = lm_ggml_get_rows(ctx0, tok_embd, inp->tokens);
 
@@ -1014,11 +888,11 @@ lm_ggml_tensor * llm_graph_context::build_inp_embd(lm_ggml_tensor * tok_embd) co
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
 
     auto & cur = inp->pos;
 
-    cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens*
+    cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens*n_pos_per_embd());
     lm_ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1027,11 +901,12 @@ lm_ggml_tensor * llm_graph_context::build_inp_pos() const {
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
 
     auto & cur = inp->attn_scale;
 
-
+    // this need to be 1x1xN for broadcasting
+    cur = lm_ggml_new_tensor_3d(ctx0, LM_GGML_TYPE_F32, 1, 1, n_tokens);
     lm_ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1079,7 +954,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
 
@@ -1096,7 +971,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
 
@@ -1154,7 +1029,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1188,18 +1063,13 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
         lm_ggml_tensor * v,
         lm_ggml_tensor * kq_b,
         lm_ggml_tensor * kq_mask,
-
+        lm_ggml_tensor * v_mla,
         float kq_scale) const {
-
-    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    //const int64_t n_head = hparams.n_head(il);
-    //const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-
+    q = lm_ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = lm_ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = lm_ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head = q->ne[2];
@@ -1229,7 +1099,23 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
 
         lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32);
 
-
+        if (v_mla) {
+#if 0
+            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+            // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
+            cur = lm_ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = lm_ggml_mul_mat(ctx0, v_mla, cur);
+#else
+            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+            // The permutations are noops and only change how the tensor data is interpreted.
+            cur = lm_ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = lm_ggml_mul_mat(ctx0, v_mla, cur);
+            cur = lm_ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = lm_ggml_cont(ctx0, cur); // Needed because lm_ggml_reshape_2d expects contiguous inputs.
+#endif
+        }
+
+        cur = lm_ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx0, k, q);
 
@@ -1267,9 +1153,14 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
 
         lm_ggml_tensor * kqv = lm_ggml_mul_mat(ctx0, v, kq);
 
-
+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = lm_ggml_mul_mat(ctx0, v_mla, kqv);
+        }
 
-        cur =
+        cur = lm_ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+
+        cur = lm_ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
     if (!cparams.offload_kqv) {
         // all nodes between the KV store and the attention output are run on the CPU
@@ -1304,6 +1195,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
         lm_ggml_tensor * k_cur,
         lm_ggml_tensor * v_cur,
         lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     LM_GGML_UNUSED(n_tokens);
@@ -1316,17 +1208,11 @@ lm_ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    lm_ggml_tensor * q =
-
-
-    lm_ggml_tensor * k = lm_ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = k_cur;
+    lm_ggml_tensor * v = v_cur;
 
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1349,22 +1235,16 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-
-
-    inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    lm_ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
+    {
+        LM_GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-
-        LM_GGML_ASSERT(hparams.n_swa > 0);
+        const auto n_kv = kv_self->get_n();
 
-        inp->
-        //cb(inp->
-        lm_ggml_set_input(inp->
+        inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        lm_ggml_set_input(inp->self_kq_mask);
 
-        inp->
+        inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@@ -1379,6 +1259,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
         lm_ggml_tensor * k_cur,
         lm_ggml_tensor * v_cur,
         lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1388,87 +1269,108 @@ lm_ggml_tensor * llm_graph_context::build_attn(
     lm_ggml_build_forward_expand(gf, v_cur);
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
 
-
-
+    // store to KV cache
+    {
+        lm_ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        lm_ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
 
-    const auto
+    const auto & kq_mask = inp->get_kq_mask();
 
-
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = kv_self->get_k(ctx0, il);
+    lm_ggml_tensor * v = kv_self->get_v(ctx0, il);
 
-
-
-    LM_GGML_ASSERT(!kv_self->recurrent);
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
 
-
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
 
-
+    if (wo_b) {
+        cur = lm_ggml_add(ctx0, cur, wo_b);
+    }
 
-
-
+    return cur;
+}
 
-
-
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
-
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
-
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
 
-
-
-
-        // note: the V cache is transposed when not using flash attention
-        v_cache_view = lm_ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                (  n_ctx)*lm_ggml_element_size(kv_self->v_l[il]),
-                (kv_head)*lm_ggml_element_size(kv_self->v_l[il]));
+        inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        lm_ggml_set_input(inp->self_kq_mask);
 
-
-
-
+        inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        LM_GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
+
+        inp->self_kq_mask_swa = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        lm_ggml_set_input(inp->self_kq_mask_swa);
 
-
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask_swa, LM_GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+lm_ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        lm_ggml_cgraph * gf,
+        lm_ggml_tensor * wo,
+        lm_ggml_tensor * wo_b,
+        lm_ggml_tensor * q_cur,
+        lm_ggml_tensor * k_cur,
+        lm_ggml_tensor * v_cur,
+        lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    lm_ggml_build_forward_expand(gf, q_cur);
+    lm_ggml_build_forward_expand(gf, k_cur);
+    lm_ggml_build_forward_expand(gf, v_cur);
+
     const bool is_swa = hparams.is_swa(il);
 
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+
+    // store to KV cache
+    {
+        lm_ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        lm_ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+    }
+
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = kv->get_k(ctx0, il);
+    lm_ggml_tensor * v = kv->get_v(ctx0, il);
 
-
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    lm_ggml_tensor * q = lm_ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    lm_ggml_tensor * k =
-        lm_ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = !v_trans ?
-        lm_ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        lm_ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                lm_ggml_element_size(kv_self->v_l[il])*n_ctx,
-                lm_ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
         cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
+        }
     }
 
     if (wo_b) {
@@ -1504,6 +1406,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
         lm_ggml_tensor * k_cur,
         lm_ggml_tensor * v_cur,
         lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1514,17 +1417,11 @@ lm_ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    lm_ggml_tensor * q =
-
-
-    lm_ggml_tensor * k = lm_ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = k_cur;
+    lm_ggml_tensor * v = v_cur;
 
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1549,7 +1446,7 @@ lm_ggml_tensor * llm_graph_context::build_copy_mask_state(
         lm_ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto n_kv = kv_self->n;
     const auto kv_head = kv_self->head;
@@ -1581,7 +1478,7 @@ lm_ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         lm_ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1602,7 +1499,7 @@ lm_ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         lm_ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1693,3 +1590,29 @@ void llm_graph_context::build_pooling(
     lm_ggml_build_forward_expand(gf, cur);
 }
 
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}