cui-llama.rn 1.6.1 → 1.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +6 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +51 -14
- package/android/src/main/java/com/rnllama/RNLlama.java +158 -6
- package/android/src/main/jni.cpp +153 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/chat.cpp +128 -106
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +38 -76
- package/cpp/common.h +23 -19
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/cpp/ggml-cpu/ggml-cpu.c +5 -13
- package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
- package/cpp/ggml-cpu/ops.cpp +107 -13
- package/cpp/ggml-cpu/vec.cpp +0 -6
- package/cpp/ggml-cpu/vec.h +16 -0
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +321 -132
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +93 -38
- package/cpp/ggml.h +21 -7
- package/cpp/gguf.cpp +33 -33
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +3 -0
- package/cpp/llama-batch.cpp +3 -1
- package/cpp/llama-chat.cpp +8 -6
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +349 -135
- package/cpp/llama-context.h +30 -3
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +150 -234
- package/cpp/llama-graph.h +52 -7
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +34 -5
- package/cpp/llama-kv-cache.cpp +662 -321
- package/cpp/llama-kv-cache.h +203 -93
- package/cpp/llama-memory.h +3 -2
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +536 -132
- package/cpp/llama-model.h +7 -1
- package/cpp/llama-sampling.cpp +18 -6
- package/cpp/llama-vocab.cpp +46 -8
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +72 -131
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +611 -47
- package/cpp/rn-llama.h +33 -3
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +7 -0
- package/ios/RNLlama.mm +77 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +105 -10
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +50 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +67 -4
- package/src/index.ts +212 -38
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
package/cpp/llama-model.h
CHANGED

```diff
@@ -76,6 +76,7 @@ enum llm_type {
     LLM_TYPE_236B,
     LLM_TYPE_290B,
     LLM_TYPE_314B,
+    LLM_TYPE_405B,
     LLM_TYPE_671B,
     LLM_TYPE_SMALL,
     LLM_TYPE_MEDIUM,
@@ -95,6 +96,8 @@ enum llm_type {
     LLM_TYPE_235B_A22B,
 };
 
+std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
+
 struct llama_layer_posnet {
     // resnet
     struct lm_ggml_tensor * norm1 = nullptr;
@@ -395,7 +398,10 @@ struct llama_model {
 
     const struct lm_ggml_tensor * get_tensor(const char * name) const;
 
-    lm_ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
+    float get_rope_freq_base (const llama_cparams & cparams, int il) const;
+    float get_rope_freq_scale(const llama_cparams & cparams, int il) const;
+
+    lm_ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
 
     // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
```
package/cpp/llama-sampling.cpp
CHANGED

```diff
@@ -1751,23 +1751,35 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
 static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
 
+    if (ctx->n <= 0.0f || cur_p->size <= 1) {
+        return;
+    }
+
     // find max logit and calculate mean
     float max = cur_p->data[0].logit;
     float logits_sum = 0;
+    size_t valid_count = 0;
     for (size_t i = 0; i < cur_p->size; ++i) {
-        if (cur_p->data[i].logit > max) {
-            max = cur_p->data[i].logit;
+        // Only count non-negative infinity values
+        if (cur_p->data[i].logit != -INFINITY) {
+            if (cur_p->data[i].logit > max) {
+                max = cur_p->data[i].logit;
+            }
+            logits_sum += cur_p->data[i].logit;
+            valid_count++;
         }
-        logits_sum += cur_p->data[i].logit;
     }
-    float mean = logits_sum/cur_p->size;
+    float mean = valid_count > 0 ? logits_sum/valid_count : 0;
 
     // calculate standard deviation
     float acc = 0;
     for (size_t i = 0; i < cur_p->size; ++i) {
-        acc += pow(cur_p->data[i].logit - mean, 2);
+        // Skip -infinity in std calculation
+        if (cur_p->data[i].logit != -INFINITY) {
+            acc += pow(cur_p->data[i].logit - mean, 2);
+        }
     }
-    float std = sqrt(acc/cur_p->size);
+    float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
 
     //apply mask
     for (size_t i = 0; i < cur_p->size; ++i) {
```
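The -INFINITY handling above changes which candidates contribute to the mean and standard deviation. A minimal standalone sketch of the same masked statistics, using a plain `std::vector<float>` instead of `llama_token_data_array` (illustrative only, not code from the package):

```cpp
// Sketch of the masked mean/std computation used by the patched top-n-sigma sampler.
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
    const float NEG_INF = -std::numeric_limits<float>::infinity();

    // Logits where some candidates were already masked with -inf by an earlier sampler.
    std::vector<float> logits = { 2.0f, NEG_INF, 1.0f, NEG_INF, 3.0f };

    float max = logits[0];
    float sum = 0.0f;
    size_t valid = 0;
    for (float l : logits) {
        if (l != NEG_INF) {       // only count finite candidates, as in the patch
            if (l > max) { max = l; }
            sum += l;
            ++valid;
        }
    }
    const float mean = valid > 0 ? sum / valid : 0.0f;

    float acc = 0.0f;
    for (float l : logits) {
        if (l != NEG_INF) { acc += (l - mean) * (l - mean); }
    }
    const float std_dev = valid > 0 ? std::sqrt(acc / valid) : 0.0f;

    // The top-n-sigma threshold would be max - n*std_dev; it stays finite because
    // the -inf entries no longer drag the mean and std toward -inf/NaN.
    std::printf("max=%.2f mean=%.2f std=%.2f\n", max, mean, std_dev);
    return 0;
}
```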
package/cpp/llama-vocab.cpp
CHANGED

```diff
@@ -1,5 +1,7 @@
 #include "llama-vocab.h"
 
+#include "ggml.h"
+#include "gguf.h"
 #include "llama-impl.h"
 #include "llama-model-loader.h"
 
@@ -415,6 +417,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_SEED_CODER:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -826,7 +835,7 @@ struct llm_tokenizer_ugm_session {
         }
 
         // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
-        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
+        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
         // at the beginning tokenization score is zero
         tokenization_results[0] = { vocab.token_unk(), 0, 0 };
 
@@ -858,7 +867,7 @@ struct llm_tokenizer_ugm_session {
                 const double challenger_score = current_best.score_sum + token_score;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
-                    struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
+                    struct best_tokenization challenger = { token_id, input_offset, challenger_score };
                     current_champ = challenger;
                 }
             }
@@ -872,7 +881,7 @@ struct llm_tokenizer_ugm_session {
                 prefix_offset = input_offset + n_utf8_code_units;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
-                    struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
+                    struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
                     current_champ = challenger;
                 }
             }
@@ -998,7 +1007,7 @@ private:
     struct best_tokenization {
         llama_token token_id;
         size_t input_offset;
-        float score_sum;
+        double score_sum;
     };
 
     struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
@@ -1227,6 +1236,9 @@ struct fragment_buffer_variant {
 struct llama_vocab::impl {
     uint32_t n_token_types = 0; // for BERT-style token types
 
+    std::string tokenizer_model;
+    std::string tokenizer_pre;
+
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 
@@ -1362,9 +1374,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
     // determine vocab type
     {
-        std::string tokenizer_model;
-        std::string tokenizer_pre;
-
         ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
         ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
 
@@ -1459,7 +1468,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
         const int precompiled_charsmap_keyidx = lm_gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
         if (precompiled_charsmap_keyidx != -1) {
-            size_t n_precompiled_charsmap = lm_gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+            const lm_gguf_type pc_type = lm_gguf_get_arr_type(ctx, precompiled_charsmap_keyidx);
+            LM_GGML_ASSERT(pc_type == LM_GGUF_TYPE_INT8 || pc_type == LM_GGUF_TYPE_UINT8);
+
+            const size_t n_precompiled_charsmap = lm_gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
             const char * pc = (const char *) lm_gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
             precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
 #ifdef IS_BIG_ENDIAN
@@ -1634,6 +1646,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "bailingmoe") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
                 clean_spaces = false;
+            } else if (
+                tokenizer_pre == "seed-coder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -2778,6 +2794,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
     pimpl->load(ml, kv);
 }
 
+std::string llama_vocab::get_tokenizer_model() const {
+    return pimpl->tokenizer_model;
+}
+
+std::string llama_vocab::get_tokenizer_pre() const {
+    return pimpl->tokenizer_pre;
+}
+
 enum llama_vocab_type llama_vocab::get_type() const {
     return pimpl->type;
 }
@@ -3000,6 +3024,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
     return it->second;
 }
 
+std::vector<std::string> llama_vocab::get_bpe_merges() const {
+    std::vector<std::string> result(pimpl->bpe_ranks.size());
+
+    for (const auto & pair : pimpl->bpe_ranks) {
+        result[pair.second] = pair.first.first + " " + pair.first.second;
+    }
+
+    return result;
+}
+
+std::vector<char> llama_vocab::get_precompiled_charsmap() const {
+    return pimpl->precompiled_charsmap;
+}
+
 int32_t llama_vocab::tokenize(
                   const char * text,
                      int32_t   text_len,
```
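The new `get_bpe_merges()` accessor rebuilds the textual "left right" merge list from the vocab's rank map, indexing the output by rank so the merges come out in training order. A standalone sketch of the same idea, with a plain `std::map` standing in for the internal `bpe_ranks` container (illustrative only):

```cpp
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
    // pair of merged symbols -> merge rank (lower rank = applied earlier)
    std::map<std::pair<std::string, std::string>, int> bpe_ranks = {
        {{"t", "h"},  0},
        {{"th", "e"}, 1},
        {{"i", "n"},  2},
    };

    // Index the result by rank, like the new llama_vocab::get_bpe_merges(),
    // so the list is ordered by rank regardless of map iteration order.
    std::vector<std::string> merges(bpe_ranks.size());
    for (const auto & pair : bpe_ranks) {
        merges[pair.second] = pair.first.first + " " + pair.first.second;
    }

    for (const auto & m : merges) {
        std::printf("%s\n", m.c_str());   // "t h", "th e", "i n"
    }
    return 0;
}
```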
package/cpp/llama-vocab.h
CHANGED

```diff
@@ -21,6 +21,9 @@ struct llama_vocab {
 
     void load(llama_model_loader & ml, const LLM_KV & kv);
 
+    std::string get_tokenizer_model() const;
+    std::string get_tokenizer_pre() const;
+
     enum llama_vocab_type     get_type()     const;
     enum llama_vocab_pre_type get_pre_type() const;
 
@@ -80,6 +83,9 @@ struct llama_vocab {
     int max_token_len() const;
 
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+    std::vector<std::string> get_bpe_merges() const;
+
+    std::vector<char> get_precompiled_charsmap() const;
 
     int32_t tokenize(
                    const char * text,
```
package/cpp/llama.cpp
CHANGED

```diff
@@ -4,6 +4,7 @@
 #include "llama-mmap.h"
 #include "llama-vocab.h"
 #include "llama-model-loader.h"
+#include "llama-model-saver.h"
 #include "llama-model.h"
 
 #include "ggml.h"
@@ -150,6 +151,11 @@ static struct llama_model * llama_model_load_from_file_impl(
         struct llama_model_params params) {
     lm_ggml_time_init();
 
+    if (!params.vocab_only && lm_ggml_backend_reg_count() == 0) {
+        LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use lm_ggml_backend_load() or lm_ggml_backend_load_all() to load a backend before calling this function\n", __func__);
+        return nullptr;
+    }
+
     unsigned cur_percentage = 0;
     if (params.progress_callback == NULL) {
         params.progress_callback_user_data = &cur_percentage;
@@ -264,6 +270,13 @@ struct llama_model * llama_model_load_from_splits(
     return llama_model_load_from_file_impl(splits.front(), splits, params);
 }
 
+void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
+    llama_model_saver ms(*model);
+    ms.add_kv_from_model();
+    ms.add_tensors_from_model();
+    ms.save(path_model);
+}
+
 //
 // chat templates
 //
@@ -349,3 +362,4 @@ const char * llama_print_system_info(void) {
 
     return s.c_str();
 }
+
```
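Together with the new `llama-model-saver` files, this adds a public `llama_model_save_to_file()` entry point. A hedged usage sketch (not shipped with the package; the GGUF paths are placeholders and error handling is minimal):

```cpp
#include "llama.h"

int main() {
    // The new guard in llama_model_load_from_file_impl() rejects a full load when no
    // backend is registered, so load backends first (as the error message suggests).
    lm_ggml_backend_load_all();

    llama_model_params params = llama_model_default_params();
    struct llama_model * model = llama_model_load_from_file("in.gguf", params);
    if (model == nullptr) {
        return 1;
    }

    // New in this release (via the bundled llama.cpp update): write the model back out as GGUF.
    llama_model_save_to_file(model, "out.gguf");

    llama_model_free(model);
    return 0;
}
```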
package/cpp/llama.h
CHANGED

```diff
@@ -4,6 +4,7 @@
 #include "ggml.h"
 #include "ggml-cpu.h"
 #include "ggml-backend.h"
+#include "ggml-opt.h"
 
 #include <stddef.h>
 #include <stdint.h>
@@ -113,6 +114,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_BAILINGMOE     = 32,
         LLAMA_VOCAB_PRE_TYPE_LLAMA4         = 33,
         LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
+        LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
     };
 
     enum llama_rope_type {
@@ -344,7 +346,7 @@ extern "C" {
         float    yarn_beta_fast;   // YaRN low correction dim
         float    yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx;    // YaRN original context size
-        float    defrag_thold;     // defragment the KV cache if holes/size > thold, < 0 disabled (default)
+        float    defrag_thold;     // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
 
         lm_ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
@@ -352,19 +354,19 @@ extern "C" {
         enum lm_ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum lm_ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
-        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
-        // TODO: move at the end of the struct
-        bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
-        bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
-
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
         // currently works only with CPU execution
         lm_ggml_abort_callback abort_callback;
         void *                 abort_callback_data;
+
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        bool embeddings;  // if true, extract embeddings (together with logits)
+        bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // use flash attention [EXPERIMENTAL]
+        bool no_perf;     // measure performance timings
+        bool op_offload;  // offload host tensor operations to device
+        bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     };
 
     // model quantization parameters
@@ -446,6 +448,10 @@ extern "C" {
                              size_t n_paths,
               struct llama_model_params params);
 
+    LLAMA_API void llama_model_save_to_file(
+            const struct llama_model * model,
+                          const char * path_model);
+
     DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
             "use llama_model_free instead");
 
@@ -603,71 +609,14 @@ extern "C" {
     // KV cache
     //
 
-    // TODO: start using struct llama_kv_cache
-
-    // Information associated with an individual cell in the KV cache view.
-    struct llama_kv_cache_view_cell {
-        // The position for this cell. Takes KV cache shifts into account.
-        // May be negative if the cell is not populated.
-        llama_pos pos;
-    };
-
-    // An updateable view of the KV cache.
-    struct llama_kv_cache_view {
-        // Number of KV cache cells. This will be the same as the context size.
-        int32_t n_cells;
-
-        // Maximum number of sequences that can exist in a cell. It's not an error
-        // if there are more sequences in a cell than this value, however they will
-        // not be visible in the view cells_sequences.
-        int32_t n_seq_max;
-
-        // Number of tokens in the cache. For example, if there are two populated
-        // cells, the first with 1 sequence id in it and the second with 2 sequence
-        // ids then you'll have 3 tokens.
-        int32_t token_count;
-
-        // Number of populated cache cells.
-        int32_t used_cells;
-
-        // Maximum contiguous empty slots in the cache.
-        int32_t max_contiguous;
-
-        // Index to the start of the max_contiguous slot range. Can be negative
-        // when cache is full.
-        int32_t max_contiguous_idx;
-
-        // Information for an individual cell.
-        struct llama_kv_cache_view_cell * cells;
-
-        // The sequences for each cell. There will be n_seq_max items per cell.
-        llama_seq_id * cells_sequences;
-    };
-
-    // Create an empty KV cache view. (use only for debugging purposes)
-    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
-
-    // Free a KV cache view. (use only for debugging purposes)
-    LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
-
-    // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
-    // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
-    LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
-
-    ///
-
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
-    LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
-
-    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
-            "use llama_kv_self_n_tokens instead");
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
+        "Use llama_kv_self_seq_pos_max() instead");
 
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
-
-    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
-            "use llama_kv_self_used_cells instead");
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
+        "Use llama_kv_self_seq_pos_max() instead");
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_self_clear(
@@ -726,10 +675,18 @@ extern "C" {
                        llama_pos   p1,
                              int   d);
 
+    // Returns the smallest position present in the KV cache for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_kv_self_seq_pos_min(
            struct llama_context * ctx,
+                    llama_seq_id   seq_id);
+
     // Returns the largest position present in the KV cache for the specified sequence
+    // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
            struct llama_context * ctx,
-            llama_seq_id   seq_id);
+                    llama_seq_id   seq_id);
 
     // Defragment the KV cache
     // This will be applied:
@@ -743,61 +700,6 @@ extern "C" {
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
     LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
 
-    DEPRECATED(LLAMA_API void llama_kv_cache_clear(
-                struct llama_context * ctx),
-        "use llama_kv_self_clear instead");
-
-    DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
-                struct llama_context * ctx,
-                        llama_seq_id   seq_id,
-                           llama_pos   p0,
-                           llama_pos   p1),
-        "use llama_kv_self_seq_rm instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
-                struct llama_context * ctx,
-                        llama_seq_id   seq_id_src,
-                        llama_seq_id   seq_id_dst,
-                           llama_pos   p0,
-                           llama_pos   p1),
-        "use llama_kv_self_seq_cp instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
-                struct llama_context * ctx,
-                        llama_seq_id   seq_id),
-        "use llama_kv_self_seq_keep instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
-                struct llama_context * ctx,
-                        llama_seq_id   seq_id,
-                           llama_pos   p0,
-                           llama_pos   p1,
-                           llama_pos   delta),
-        "use llama_kv_self_seq_add instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
-                struct llama_context * ctx,
-                        llama_seq_id   seq_id,
-                           llama_pos   p0,
-                           llama_pos   p1,
-                                 int   d),
-        "use llama_kv_self_seq_div instead");
-
-    DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
-                struct llama_context * ctx,
-                        llama_seq_id   seq_id),
-        "use llama_kv_self_seq_pos_max instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
-        "use llama_kv_self_defrag instead");
-
-    DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
-        "use llama_kv_self_can_shift instead");
-
-    DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
-        "use llama_kv_self_update instead");
-
-
     //
     // State / sessions
     //
@@ -925,18 +827,26 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);
 
-    //
-    //
+    // Process a batch of tokens.
+    // In contrast to llama_decode() - this call does not use KV cache.
+    // For encode-decoder contexts, processes the batch using the encoder.
+    // Can store the encoder output internally for later use by the decoder's cross-attention layers.
     // 0 - success
     // < 0 - error. the KV cache state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
            struct llama_context * ctx,
              struct llama_batch   batch);
 
+    // Process a batch of tokens.
+    // Requires KV cache.
+    // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    //
-    //
-    //
+    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    //    0 - success
+    //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    //    2 - aborted
+    //   -1 - invalid input batch
+    // < -1 - error
     LLAMA_API int32_t llama_decode(
            struct llama_context * ctx,
              struct llama_batch   batch);
@@ -1429,6 +1339,37 @@ extern "C" {
     LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
     LLAMA_API void llama_perf_sampler_reset(      struct llama_sampler * chain);
 
+    //
+    // training
+    //
+
+    // function that returns whether or not a given tensor contains trainable parameters
+    typedef bool (*llama_opt_param_filter)(const struct lm_ggml_tensor * tensor, void * userdata);
+
+    // always returns true
+    LLAMA_API bool llama_opt_param_filter_all(const struct lm_ggml_tensor * tensor, void * userdata);
+
+    struct llama_opt_params {
+        uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
+
+        llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
+        void * param_filter_ud;              // userdata for determining which tensors contain trainable parameters
+
+        lm_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                        // userdata for calculating optimizer parameters
+    };
+
+    LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
+
+    LLAMA_API void llama_opt_epoch(
+            struct llama_context     * lctx,
+            lm_ggml_opt_dataset_t      dataset,
+            lm_ggml_opt_result_t       result_train,
+            lm_ggml_opt_result_t       result_eval,
+            int64_t                    idata_split,
+            lm_ggml_opt_epoch_callback callback_train,
+            lm_ggml_opt_epoch_callback callback_eval);
+
 #ifdef __cplusplus
 }
 #endif
```
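Since the header now spells out the individual `llama_decode()` return codes, here is a hedged sketch (not part of the package) of a caller that branches on each documented value; `ctx` and `batch` are assumed to be prepared elsewhere:

```cpp
#include "llama.h"

// Returns llama_decode()'s code unchanged; the branches just mirror the documented meanings.
int32_t run_decode(struct llama_context * ctx, struct llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 0) {
        // success
    } else if (ret == 1) {
        // no KV slot for the batch: retry with a smaller batch or a larger context
    } else if (ret == 2) {
        // aborted via the abort callback; KV cache state was restored
    } else if (ret == -1) {
        // invalid input batch
    } else {
        // < -1: unrecoverable error
    }
    return ret;
}
```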
package/cpp/minja/chat-template.hpp
CHANGED

```diff
@@ -13,10 +13,12 @@
 #include <chrono>
 #include <cstddef>
 #include <cstdio>
+#include <ctime>
 #include <exception>
 #include <iomanip>
 #include <memory>
 #include <sstream>
+#include <stdexcept>
 #include <string>
 #include <vector>
 
@@ -393,8 +395,8 @@ class chat_template {
 
         for (const auto & message_ : adjusted_messages) {
             auto message = message_;
-            if (!message.contains("role") || !message.contains("content")) {
-                throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+            if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
+                throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
             }
             std::string role = message.at("role");
 
@@ -415,7 +417,6 @@ class chat_template {
                     }
                 }
                 if (polyfill_tool_calls) {
-                    auto content = message.at("content");
                     auto tool_calls = json::array();
                     for (const auto & tool_call : message.at("tool_calls")) {
                         if (tool_call.at("type") != "function") {
@@ -434,8 +435,11 @@ class chat_template {
                     auto obj = json {
                         {"tool_calls", tool_calls},
                     };
-                    if (!content.is_null() && !content.empty()) {
-                        obj["content"] = content;
+                    if (message.contains("content")) {
+                        auto content = message.at("content");
+                        if (!content.is_null() && !content.empty()) {
+                            obj["content"] = content;
+                        }
                     }
                     message["content"] = obj.dump(2);
                     message.erase("tool_calls");
```
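A hedged illustration (not part of minja) of the message shape the relaxed check now accepts: an assistant turn carrying `tool_calls` but no `content` field, built with nlohmann::json as the template code uses:

```cpp
#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    const json tool_call = {
        {"type", "function"},
        {"function", {{"name", "get_weather"}, {"arguments", "{\"city\":\"Paris\"}"}}},
    };
    const json message = {
        {"role", "assistant"},
        {"tool_calls", json::array({tool_call})},
        // note: no "content" key at all
    };

    // Old check: required both "role" and "content", so this message would throw.
    // New check: "role" plus either "content" or "tool_calls" is accepted.
    const bool ok = message.contains("role") &&
                    (message.contains("content") || message.contains("tool_calls"));
    std::printf("accepted: %s\n", ok ? "yes" : "no");
    return 0;
}
```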