@novastera-oss/llamarn 0.2.5 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/RNLlamaCpp.podspec +3 -2
- package/android/CMakeLists.txt +6 -3
- package/android/src/main/cpp/include/llama.h +140 -38
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +48 -67
- package/cpp/LlamaCppModel.h +8 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +33 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
- package/cpp/llama.cpp/common/arg.cpp +38 -12
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
- package/cpp/llama.cpp/common/chat-parser.h +4 -1
- package/cpp/llama.cpp/common/chat.cpp +16 -13
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +52 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/json-partial.cpp +5 -4
- package/cpp/llama.cpp/common/json-partial.h +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
- package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +140 -38
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
- package/cpp/llama.cpp/src/llama-batch.h +47 -17
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +488 -313
- package/cpp/llama.cpp/src/llama-context.h +38 -17
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
- package/cpp/llama.cpp/src/llama-graph.h +109 -52
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
- package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +89 -4
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +735 -143
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
- package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
- package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
- package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
- package/cpp/rn-completion.cpp +65 -10
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common/minja/chat-template.hpp +1 -1
- package/ios/include/common/minja/minja.hpp +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/json-schema-to-grammar.h +4 -4
- package/ios/include/llama.h +140 -38
- package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
- package/ios/libs/llama.xcframework/Info.plist +20 -20
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
|
@@ -1,8 +1,14 @@
|
|
|
1
1
|
#include "llama-batch.h"
|
|
2
2
|
|
|
3
|
+
#include "llama-impl.h"
|
|
4
|
+
#include "llama-cparams.h"
|
|
5
|
+
#include "llama-vocab.h"
|
|
6
|
+
#include "llama-memory.h"
|
|
7
|
+
|
|
3
8
|
#include <cassert>
|
|
4
9
|
#include <cstring>
|
|
5
10
|
#include <algorithm>
|
|
11
|
+
#include <sstream>
|
|
6
12
|
|
|
7
13
|
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
|
|
8
14
|
// clear empty sequences
|
|
@@ -15,24 +21,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
|
|
|
15
21
|
break;
|
|
16
22
|
}
|
|
17
23
|
}
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
+
|
|
25
|
+
udatas.push_back({});
|
|
26
|
+
|
|
27
|
+
auto & udata = udatas.back();
|
|
28
|
+
|
|
29
|
+
udata.token.resize(!has_embd ? n_ubatch : 0);
|
|
30
|
+
udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
|
31
|
+
udata.pos.resize(n_ubatch);
|
|
32
|
+
udata.n_seq_id.resize(n_ubatch);
|
|
33
|
+
udata.seq_id.resize(n_ubatch);
|
|
34
|
+
udata.output.resize(n_ubatch);
|
|
35
|
+
|
|
24
36
|
llama_ubatch ubatch = {
|
|
25
37
|
/*equal_seqs =*/ true,
|
|
26
38
|
/*n_tokens =*/ 0,
|
|
27
39
|
/*n_seq_tokens =*/ 0,
|
|
28
40
|
/*n_seqs =*/ 0,
|
|
29
|
-
/*token =*/ !has_embd ?
|
|
30
|
-
/*embd =*/ has_embd ?
|
|
31
|
-
/*pos =*/
|
|
32
|
-
/*n_seq_id =*/
|
|
33
|
-
/*seq_id =*/
|
|
34
|
-
/*output =*/
|
|
41
|
+
/*token =*/ !has_embd ? udata.token.data() : nullptr,
|
|
42
|
+
/*embd =*/ has_embd ? udata.embd.data() : nullptr,
|
|
43
|
+
/*pos =*/ udata.pos.data(),
|
|
44
|
+
/*n_seq_id =*/ udata.n_seq_id.data(),
|
|
45
|
+
/*seq_id =*/ udata.seq_id.data(),
|
|
46
|
+
/*output =*/ udata.output.data(),
|
|
35
47
|
};
|
|
48
|
+
|
|
36
49
|
return ubatch;
|
|
37
50
|
}
|
|
38
51
|
|
|
@@ -98,12 +111,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
|
|
|
98
111
|
ubatch.seq_id = batch->seq_id + seq.offset;
|
|
99
112
|
}
|
|
100
113
|
}
|
|
101
|
-
if (
|
|
102
|
-
for (size_t i = 0; i < length; ++i) {
|
|
103
|
-
ubatch.output[ubatch.n_tokens + i] = 1;
|
|
104
|
-
out_ids.push_back(ids[seq.offset + i]);
|
|
105
|
-
}
|
|
106
|
-
} else if (batch->logits) {
|
|
114
|
+
if (batch->logits) {
|
|
107
115
|
if (ubatch.equal_seqs) {
|
|
108
116
|
for (size_t i = 0; i < length; ++i) {
|
|
109
117
|
size_t id = ids[seq.offset + i];
|
|
@@ -190,11 +198,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
|
|
190
198
|
return ubatch;
|
|
191
199
|
}
|
|
192
200
|
|
|
193
|
-
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split
|
|
201
|
+
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
|
|
194
202
|
GGML_ASSERT(batch.n_tokens >= 0);
|
|
195
203
|
this->batch = &batch;
|
|
196
204
|
this->n_embd = n_embd;
|
|
197
|
-
this->logits_all = logits_all;
|
|
198
205
|
|
|
199
206
|
n_tokens = batch.n_tokens;
|
|
200
207
|
ids.resize(n_tokens);
|
|
@@ -278,17 +285,56 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
|
|
|
278
285
|
);
|
|
279
286
|
}
|
|
280
287
|
|
|
281
|
-
llama_batch_allocr::llama_batch_allocr(
|
|
282
|
-
|
|
288
|
+
llama_batch_allocr::llama_batch_allocr() {
|
|
289
|
+
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
|
290
|
+
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
|
291
|
+
|
|
292
|
+
seq_pos.resize(LLAMA_MAX_SEQ);
|
|
293
|
+
seq_cpl.resize(LLAMA_MAX_SEQ);
|
|
294
|
+
for (auto & cur : seq_cpl) {
|
|
295
|
+
cur.resize(LLAMA_MAX_SEQ);
|
|
296
|
+
}
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
bool llama_batch_allocr::init(
|
|
300
|
+
const llama_batch & batch_inp,
|
|
301
|
+
const llama_vocab & vocab,
|
|
302
|
+
const llama_memory_i * memory,
|
|
303
|
+
bool embd_all) {
|
|
304
|
+
clear();
|
|
305
|
+
|
|
306
|
+
batch = batch_inp;
|
|
307
|
+
|
|
283
308
|
GGML_ASSERT(batch.n_tokens > 0);
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
309
|
+
|
|
310
|
+
//
|
|
311
|
+
// validate input batch
|
|
312
|
+
//
|
|
313
|
+
|
|
314
|
+
if (batch.token) {
|
|
315
|
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
316
|
+
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
|
|
317
|
+
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
|
318
|
+
return false;
|
|
319
|
+
}
|
|
289
320
|
}
|
|
290
|
-
batch.pos = pos.data();
|
|
291
321
|
}
|
|
322
|
+
|
|
323
|
+
if (batch.seq_id) {
|
|
324
|
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
325
|
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
326
|
+
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
|
|
327
|
+
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
|
|
328
|
+
return false;
|
|
329
|
+
}
|
|
330
|
+
}
|
|
331
|
+
}
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
//
|
|
335
|
+
// auto-generate missing fields
|
|
336
|
+
//
|
|
337
|
+
|
|
292
338
|
if (!batch.n_seq_id) {
|
|
293
339
|
n_seq_id.resize(batch.n_tokens);
|
|
294
340
|
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
@@ -296,6 +342,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
|
|
|
296
342
|
}
|
|
297
343
|
batch.n_seq_id = n_seq_id.data();
|
|
298
344
|
}
|
|
345
|
+
|
|
299
346
|
if (!batch.seq_id) {
|
|
300
347
|
seq_id.resize(batch.n_tokens + 1);
|
|
301
348
|
seq_id[batch.n_tokens] = NULL;
|
|
@@ -304,10 +351,221 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
|
|
|
304
351
|
}
|
|
305
352
|
batch.seq_id = seq_id.data();
|
|
306
353
|
}
|
|
354
|
+
|
|
355
|
+
if (!batch.pos) {
|
|
356
|
+
pos.resize(batch.n_tokens);
|
|
357
|
+
|
|
358
|
+
// initialize the starting position for each sequence based on the positions in the memory
|
|
359
|
+
llama_pos p0[LLAMA_MAX_SEQ];
|
|
360
|
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
361
|
+
if (!memory) {
|
|
362
|
+
p0[s] = 0;
|
|
363
|
+
} else {
|
|
364
|
+
p0[s] = memory->seq_pos_max(s) + 1;
|
|
365
|
+
}
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
369
|
+
const llama_seq_id seq_id = batch.seq_id[i][0];
|
|
370
|
+
|
|
371
|
+
pos[i] = p0[seq_id];
|
|
372
|
+
|
|
373
|
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
374
|
+
p0[batch.seq_id[i][s]] = pos[i] + 1;
|
|
375
|
+
}
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
batch.pos = pos.data();
|
|
379
|
+
}
|
|
380
|
+
|
|
307
381
|
if (!batch.logits) {
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
382
|
+
if (embd_all) {
|
|
383
|
+
// return the output for all tokens
|
|
384
|
+
output.resize(batch.n_tokens, true);
|
|
385
|
+
} else {
|
|
386
|
+
// return the output only for the last token
|
|
387
|
+
output.resize(batch.n_tokens, false);
|
|
388
|
+
output[output.size() - 1] = true;
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
batch.logits = output.data();
|
|
392
|
+
} else if (embd_all) {
|
|
393
|
+
bool warn = false;
|
|
394
|
+
|
|
395
|
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
396
|
+
if (batch.logits[i] == 0) {
|
|
397
|
+
warn = true;
|
|
398
|
+
}
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
if (warn) {
|
|
402
|
+
LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
|
|
403
|
+
|
|
404
|
+
output.resize(batch.n_tokens, true);
|
|
405
|
+
batch.logits = output.data();
|
|
406
|
+
}
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
//
|
|
410
|
+
// compute stats
|
|
411
|
+
//
|
|
412
|
+
|
|
413
|
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
414
|
+
n_outputs += batch.logits[i] != 0;
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
// determine coupled sequences
|
|
418
|
+
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
|
|
419
|
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
420
|
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
421
|
+
seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
|
|
422
|
+
|
|
423
|
+
if (s > 0) {
|
|
424
|
+
const llama_seq_id s0 = batch.seq_id[i][0];
|
|
425
|
+
const llama_seq_id s1 = batch.seq_id[i][s];
|
|
426
|
+
|
|
427
|
+
// mark that sequence s1 is coupled to s0
|
|
428
|
+
seq_cpl[s1][s0] = true;
|
|
429
|
+
|
|
430
|
+
// note: the other way around is not necessary for now
|
|
431
|
+
//seq_cpl[s0][s1] = true;
|
|
432
|
+
}
|
|
433
|
+
}
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
if (debug > 0) {
|
|
437
|
+
LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
|
|
438
|
+
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
|
|
439
|
+
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
|
|
440
|
+
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
|
|
441
|
+
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
|
|
442
|
+
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
|
|
443
|
+
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
|
|
444
|
+
LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
|
|
445
|
+
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
|
|
446
|
+
|
|
447
|
+
if (debug > 1) {
|
|
448
|
+
int seq_id_max = 0;
|
|
449
|
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
450
|
+
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
451
|
+
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
452
|
+
seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
|
|
453
|
+
}
|
|
454
|
+
}
|
|
455
|
+
}
|
|
456
|
+
++seq_id_max;
|
|
457
|
+
|
|
458
|
+
LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
|
|
459
|
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
460
|
+
std::vector<int8_t> seq_id(seq_id_max);
|
|
461
|
+
|
|
462
|
+
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
463
|
+
seq_id[batch.seq_id[i][s]] = 1;
|
|
464
|
+
}
|
|
465
|
+
|
|
466
|
+
std::stringstream ss;
|
|
467
|
+
for (int s = 0; s < seq_id_max; ++s) {
|
|
468
|
+
if (seq_id[s]) {
|
|
469
|
+
ss << s%10;
|
|
470
|
+
} else {
|
|
471
|
+
ss << ".";
|
|
472
|
+
}
|
|
473
|
+
}
|
|
474
|
+
|
|
475
|
+
LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
|
|
476
|
+
__func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
|
|
477
|
+
batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
|
|
478
|
+
}
|
|
479
|
+
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
|
480
|
+
|
|
481
|
+
LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
|
|
482
|
+
for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
|
|
483
|
+
if (seq_pos[s0].empty()) {
|
|
484
|
+
continue;
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
std::stringstream ss;
|
|
488
|
+
for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
|
|
489
|
+
if (seq_cpl[s0][s1]) {
|
|
490
|
+
ss << s1 << " ";
|
|
491
|
+
}
|
|
492
|
+
}
|
|
493
|
+
|
|
494
|
+
LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
|
|
495
|
+
__func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
|
|
496
|
+
}
|
|
497
|
+
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
|
498
|
+
}
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
//
|
|
502
|
+
// consistency checks
|
|
503
|
+
//
|
|
504
|
+
|
|
505
|
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
506
|
+
if (seq_pos[s].empty()) {
|
|
507
|
+
continue;
|
|
508
|
+
}
|
|
509
|
+
|
|
510
|
+
if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
|
|
511
|
+
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
|
|
512
|
+
return false;
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
|
516
|
+
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
|
517
|
+
return false;
|
|
518
|
+
}
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
if (memory) {
|
|
522
|
+
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
|
|
523
|
+
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
|
|
524
|
+
if (seq_cpl[s0][s1]) {
|
|
525
|
+
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
|
526
|
+
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
|
527
|
+
LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
|
|
528
|
+
return false;
|
|
529
|
+
}
|
|
530
|
+
}
|
|
531
|
+
}
|
|
532
|
+
}
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
return true;
|
|
536
|
+
}
|
|
537
|
+
|
|
538
|
+
const llama_batch & llama_batch_allocr::get_batch() const {
|
|
539
|
+
return batch;
|
|
540
|
+
}
|
|
541
|
+
|
|
542
|
+
uint32_t llama_batch_allocr::get_n_outputs() const {
|
|
543
|
+
return n_outputs;
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
|
|
547
|
+
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
|
551
|
+
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
|
|
552
|
+
}
|
|
553
|
+
|
|
554
|
+
void llama_batch_allocr::clear() {
|
|
555
|
+
n_outputs = 0;
|
|
556
|
+
|
|
557
|
+
batch = {};
|
|
558
|
+
pos.clear();
|
|
559
|
+
n_seq_id.clear();
|
|
560
|
+
seq_id.clear();
|
|
561
|
+
output.clear();
|
|
562
|
+
|
|
563
|
+
for (auto & cur : seq_pos) {
|
|
564
|
+
cur.clear();
|
|
565
|
+
}
|
|
566
|
+
|
|
567
|
+
for (auto & cur : seq_cpl) {
|
|
568
|
+
std::fill(cur.begin(), cur.end(), false);
|
|
311
569
|
}
|
|
312
570
|
}
|
|
313
571
|
|
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
|
|
5
5
|
#include <array>
|
|
6
6
|
#include <vector>
|
|
7
|
+
#include <set>
|
|
7
8
|
|
|
8
9
|
// very similar to llama_batch,
|
|
9
10
|
// but has more metadata about sequences
|
|
@@ -11,7 +12,7 @@ struct llama_ubatch {
|
|
|
11
12
|
bool equal_seqs;
|
|
12
13
|
// TODO: whole_seqs for embeddings?
|
|
13
14
|
|
|
14
|
-
uint32_t n_tokens;
|
|
15
|
+
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
|
15
16
|
uint32_t n_seq_tokens; // tokens per sequence
|
|
16
17
|
uint32_t n_seqs;
|
|
17
18
|
|
|
@@ -39,8 +40,6 @@ struct llama_sbatch {
|
|
|
39
40
|
|
|
40
41
|
size_t n_embd;
|
|
41
42
|
|
|
42
|
-
bool logits_all; // TODO: remove once lctx.logits_all is removed too
|
|
43
|
-
|
|
44
43
|
// sorted indices into the batch
|
|
45
44
|
std::vector<int64_t> ids;
|
|
46
45
|
// batch indices of the output
|
|
@@ -49,13 +48,18 @@ struct llama_sbatch {
|
|
|
49
48
|
|
|
50
49
|
const llama_batch * batch = nullptr;
|
|
51
50
|
|
|
52
|
-
// buffers for the
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
51
|
+
// buffers for the ubatches
|
|
52
|
+
// TODO: very hacky, this needs a complete rework
|
|
53
|
+
struct ubatch_data {
|
|
54
|
+
std::vector<llama_token> token;
|
|
55
|
+
std::vector<float> embd;
|
|
56
|
+
std::vector<llama_pos> pos;
|
|
57
|
+
std::vector<int32_t> n_seq_id;
|
|
58
|
+
std::vector<llama_seq_id *> seq_id;
|
|
59
|
+
std::vector<int8_t> output;
|
|
60
|
+
};
|
|
61
|
+
|
|
62
|
+
std::vector<ubatch_data> udatas;
|
|
59
63
|
|
|
60
64
|
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
|
|
61
65
|
|
|
@@ -71,19 +75,45 @@ struct llama_sbatch {
|
|
|
71
75
|
llama_ubatch split_seq(size_t n_ubatch);
|
|
72
76
|
|
|
73
77
|
llama_sbatch() = default;
|
|
74
|
-
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false
|
|
78
|
+
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
|
|
75
79
|
};
|
|
76
80
|
|
|
77
|
-
//
|
|
78
|
-
|
|
79
|
-
|
|
81
|
+
// a helper for sanitizing and fulfilling a batch
|
|
82
|
+
class llama_batch_allocr {
|
|
83
|
+
public:
|
|
84
|
+
llama_batch_allocr();
|
|
85
|
+
|
|
86
|
+
// sanitize and auto-gen missing data in the input batch
|
|
87
|
+
// memory is optional. if provided will be used to check for sequence continuity and to determine the positions
|
|
88
|
+
bool init(
|
|
89
|
+
const llama_batch & batch_inp,
|
|
90
|
+
const llama_vocab & vocab,
|
|
91
|
+
const llama_memory_i * memory,
|
|
92
|
+
bool embd_all);
|
|
93
|
+
|
|
94
|
+
const llama_batch & get_batch() const;
|
|
95
|
+
|
|
96
|
+
uint32_t get_n_outputs() const;
|
|
97
|
+
|
|
98
|
+
llama_pos seq_pos_min(llama_seq_id seq_id) const;
|
|
99
|
+
llama_pos seq_pos_max(llama_seq_id seq_id) const;
|
|
100
|
+
|
|
101
|
+
private:
|
|
102
|
+
void clear();
|
|
103
|
+
|
|
104
|
+
llama_batch batch;
|
|
105
|
+
|
|
106
|
+
uint32_t n_outputs;
|
|
80
107
|
|
|
81
108
|
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
|
109
|
+
|
|
82
110
|
std::vector<llama_pos> pos;
|
|
83
111
|
std::vector<int32_t> n_seq_id;
|
|
84
112
|
std::vector<llama_seq_id *> seq_id;
|
|
85
|
-
std::vector<int8_t>
|
|
113
|
+
std::vector<int8_t> output;
|
|
114
|
+
|
|
115
|
+
std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
|
116
|
+
std::vector<std::vector<bool>> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
|
86
117
|
|
|
87
|
-
|
|
88
|
-
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
|
|
118
|
+
int debug;
|
|
89
119
|
};
|
|
@@ -183,6 +183,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|
|
183
183
|
return LLM_CHAT_TEMPLATE_BAILING;
|
|
184
184
|
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
|
|
185
185
|
return LLM_CHAT_TEMPLATE_LLAMA4;
|
|
186
|
+
} else if (tmpl_contains("<|endofuserprompt|>")) {
|
|
187
|
+
return LLM_CHAT_TEMPLATE_DOTS1;
|
|
186
188
|
}
|
|
187
189
|
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
|
188
190
|
}
|
|
@@ -331,7 +333,7 @@ int32_t llm_chat_apply_template(
|
|
|
331
333
|
std::string role(message->role);
|
|
332
334
|
if (role == "system") {
|
|
333
335
|
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
|
|
334
|
-
system_prompt
|
|
336
|
+
system_prompt += trim(message->content);
|
|
335
337
|
continue;
|
|
336
338
|
}
|
|
337
339
|
// in gemma, "assistant" is "model"
|
|
@@ -353,7 +355,7 @@ int32_t llm_chat_apply_template(
|
|
|
353
355
|
std::string role(message->role);
|
|
354
356
|
if (role == "system") {
|
|
355
357
|
// there is no system message support, we will merge it with user prompt
|
|
356
|
-
system_prompt
|
|
358
|
+
system_prompt += message->content;
|
|
357
359
|
continue;
|
|
358
360
|
} else if (role == "user") {
|
|
359
361
|
ss << "Human: ";
|
|
@@ -643,6 +645,21 @@ int32_t llm_chat_apply_template(
|
|
|
643
645
|
if (add_ass) {
|
|
644
646
|
ss << "Assistant:";
|
|
645
647
|
}
|
|
648
|
+
} else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
|
|
649
|
+
// dots.llm1.inst (DOTS1)
|
|
650
|
+
for (auto message : chat) {
|
|
651
|
+
std::string role(message->role);
|
|
652
|
+
if (role == "system") {
|
|
653
|
+
ss << "<|system|>" << message->content << "<|endofsystem|>";
|
|
654
|
+
} else if (role == "user") {
|
|
655
|
+
ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
|
|
656
|
+
} else {
|
|
657
|
+
ss << "<|response|>" << message->content << "<|endofresponse|>";
|
|
658
|
+
}
|
|
659
|
+
}
|
|
660
|
+
if (add_ass) {
|
|
661
|
+
ss << "<|response|>";
|
|
662
|
+
}
|
|
646
663
|
} else {
|
|
647
664
|
// template not supported
|
|
648
665
|
return -1;
|