cui-llama.rn 1.6.1 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +6 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +38 -5
- package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
- package/android/src/main/jni.cpp +153 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/chat.cpp +128 -106
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +41 -76
- package/cpp/common.h +23 -19
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/cpp/ggml-cpu/ggml-cpu.c +5 -13
- package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
- package/cpp/ggml-cpu/ops.cpp +107 -13
- package/cpp/ggml-cpu/vec.cpp +0 -6
- package/cpp/ggml-cpu/vec.h +16 -0
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +321 -132
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +93 -38
- package/cpp/ggml.h +21 -7
- package/cpp/gguf.cpp +33 -33
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +3 -0
- package/cpp/llama-batch.cpp +3 -1
- package/cpp/llama-chat.cpp +8 -6
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +349 -135
- package/cpp/llama-context.h +30 -3
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +150 -234
- package/cpp/llama-graph.h +52 -7
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +34 -5
- package/cpp/llama-kv-cache.cpp +662 -321
- package/cpp/llama-kv-cache.h +203 -93
- package/cpp/llama-memory.h +3 -2
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +536 -132
- package/cpp/llama-model.h +7 -1
- package/cpp/llama-sampling.cpp +18 -6
- package/cpp/llama-vocab.cpp +46 -8
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +72 -131
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +611 -47
- package/cpp/rn-llama.h +33 -3
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +7 -0
- package/ios/RNLlama.mm +77 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +105 -10
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +50 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +67 -4
- package/src/index.ts +212 -38
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
package/cpp/llama-context.cpp
CHANGED
@@ -94,6 +94,8 @@ llama_context::llama_context(
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
+    cparams.op_offload = params.op_offload;
+
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
@@ -116,8 +118,6 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    logits_all = params.logits_all;
-
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -177,8 +177,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
         llama_memory_params params_mem = {
-            /*.type_k =*/ params.type_k,
-            /*.type_v =*/ params.type_v,
+            /*.type_k   =*/ params.type_k,
+            /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
         };
 
         memory.reset(model.create_memory(params_mem, cparams));
@@ -245,7 +246,7 @@ llama_context::llama_context(
             }
         }
 
-        sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+        sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
 
         if (pipeline_parallel) {
             LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, lm_ggml_backend_sched_get_n_copies(sched.get()));
@@ -253,7 +254,7 @@ llama_context::llama_context(
     }
 
     // reserve worst-case graph
-    if (!hparams.vocab_only) {
+    if (!hparams.vocab_only && memory) {
         const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
@@ -360,7 +361,9 @@ llama_context::llama_context(
     }
 }
 
-llama_context::~llama_context() = default;
+llama_context::~llama_context() {
+    lm_ggml_opt_free(opt_ctx);
+}
 
 void llama_context::synchronize() {
     lm_ggml_backend_sched_synchronize(sched.get());
@@ -702,6 +705,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = lm_ggml_time_us();
     }
 
+    embd_seq.clear();
+
     n_queued_tokens += n_tokens;
 
     const int64_t n_embd = hparams.n_embd;
@@ -763,12 +768,12 @@ int llama_context::encode(llama_batch & inp_batch) {
         lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         LM_GGML_ASSERT(backend_embd != nullptr);
 
-        LM_GGML_ASSERT(embd != nullptr);
-
         switch (cparams.pooling_type) {
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
+                    LM_GGML_ASSERT(embd != nullptr);
+
                     LM_GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                     lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
                 } break;
@@ -793,11 +798,18 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
-                    //       wait for an encoder model that requires this pooling type in order to test it
-                    //       https://github.com/ggerganov/llama.cpp/pull/9510
-                    LM_GGML_ABORT("RANK pooling not implemented yet");
-                } break;
+                    // extract the rerank score - a single float per sequence
+                    auto & embd_seq_out = embd_seq;
+
+                    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(1);
+                        lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                    }
+                } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     LM_GGML_ABORT("unknown pooling type");
@@ -835,16 +847,27 @@ int llama_context::encode(llama_batch & inp_batch) {
 }
 
 int llama_context::decode(llama_batch & inp_batch) {
+    if (!memory) {
+        LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+        return encode(inp_batch);
+    }
+
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    if (!inp_batch.pos) {
+        if (inp_batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return -1;
+        }
+    }
+
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -890,7 +913,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs_all += batch.logits[i] != 0;
         }
-    } else if (logits_all || embd_pooled) {
+    } else if (embd_pooled) {
         n_outputs_all = n_tokens_all;
     } else {
         // keep last output only
@@ -932,8 +955,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
             return 1;
         }
 
@@ -1689,10 +1710,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
-    LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    kv_self->state_write(io);
+    if (kv_self != nullptr) {
+        LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+        kv_self->state_write(io);
+    }
 
     return io.n_bytes();
 }
@@ -1775,10 +1798,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         }
     }
 
-    LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    if (memory) {
+        LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
 
-    kv_self->state_read(io);
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->state_read(io);
+    }
 
     return io.n_bytes();
 }
@@ -1786,9 +1812,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
     LM_GGML_UNUSED(seq_id);
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    kv_self->state_write(io, seq_id);
+        kv_self->state_write(io, seq_id);
+    }
 
     return io.n_bytes();
 }
@@ -1796,9 +1824,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
     LM_GGML_UNUSED(seq_id);
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    kv_self->state_read(io, seq_id);
+        kv_self->state_read(io, seq_id);
+    }
 
     return io.n_bytes();
 }
@@ -1826,6 +1856,215 @@ void llama_context::perf_reset() {
     t_p_eval_us = n_p_eval = 0;
 }
 
+//
+// training
+//
+
+static void llama_set_param(struct lm_ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+    if (!tensor || tensor->type != LM_GGML_TYPE_F32) {
+        return;
+    }
+    if (!param_filter(tensor, userdata)) {
+        return;
+    }
+    if (strcmp(tensor->name, "token_embd.weight") == 0) {
+        return; // FIXME
+    }
+    if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
+        return; // FIXME
+    }
+    lm_ggml_set_param(tensor);
+}
+
+void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
+    LM_GGML_ASSERT(!opt_ctx);
+    model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
+    const uint32_t n_batch  = std::min(this->n_batch(),  model->hparams.n_ctx_train);
+    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+    LM_GGML_ASSERT(model->hparams.n_ctx_train % n_batch  == 0);
+    LM_GGML_ASSERT(n_batch                    % n_ubatch == 0);
+
+    lm_ggml_opt_params opt_params = lm_ggml_opt_default_params(sched.get(), LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+    opt_params.opt_period      = n_batch / n_ubatch;
+    opt_params.get_opt_pars    = lopt_params.get_opt_pars;
+    opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
+
+    opt_ctx = lm_ggml_opt_init(opt_params);
+
+    llama_opt_param_filter param_filter = lopt_params.param_filter;
+    void * param_filter_ud              = lopt_params.param_filter_ud;
+
+    //llama_set_param(model->tok_embd,        param_filter, param_filter_ud); // FIXME
+    llama_set_param(model->type_embd,       param_filter, param_filter_ud);
+    llama_set_param(model->pos_embd,        param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm,        param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm_b,      param_filter, param_filter_ud);
+    llama_set_param(model->output_norm,     param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_b,   param_filter, param_filter_ud);
+    llama_set_param(model->output,          param_filter, param_filter_ud);
+    llama_set_param(model->output_b,        param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+    llama_set_param(model->cls,             param_filter, param_filter_ud);
+    llama_set_param(model->cls_b,           param_filter, param_filter_ud);
+    llama_set_param(model->cls_out,         param_filter, param_filter_ud);
+    llama_set_param(model->cls_out_b,       param_filter, param_filter_ud);
+
+    for (struct llama_layer & layer : model->layers) {
+        for (size_t i = 0; i < sizeof(layer)/sizeof(struct lm_ggml_tensor *); ++i) {
+            llama_set_param(reinterpret_cast<struct lm_ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+        }
+    }
+}
+
+void llama_context::opt_epoch_iter(
+        lm_ggml_opt_dataset_t            dataset,
+        lm_ggml_opt_result_t             result,
+        const std::vector<llama_token> & tokens,
+        const std::vector<llama_token> & labels_sparse,
+        llama_batch                    & batch,
+        lm_ggml_opt_epoch_callback       callback,
+        bool                             train,
+        int64_t                          idata_in_loop,
+        int64_t                          ndata_in_loop,
+        int64_t                          t_loop_start) {
+    LM_GGML_ASSERT(opt_ctx);
+    const uint32_t n_ctx    = llama_model_n_ctx_train(&model);
+    const uint32_t n_batch  = std::min(this->n_batch(),  n_ctx);
+    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+    kv_self->clear();
+    llama_kv_cache_guard kv_guard(kv_self);
+
+    for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+        batch.n_tokens = n_batch;
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
+            batch.token   [pos_batch]    = tokens[pos_ctx + pos_batch];
+            batch.pos     [pos_batch]    = pos_ctx + pos_batch;
+            batch.n_seq_id[pos_batch]    = 1;
+            batch.seq_id  [pos_batch][0] = 0;
+            batch.logits  [pos_batch]    = true;
+        }
+
+        const auto n_tokens_all = batch.n_tokens;
+
+        n_queued_tokens += n_tokens_all;
+
+        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        embd_seq.clear();
+
+        int64_t n_outputs_all = n_tokens_all;
+
+        llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+
+        // reserve output buffer
+        if (output_reserve(n_outputs_all) < n_outputs_all) {
+            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+            LM_GGML_ABORT("TODO: handle this error");
+        };
+
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
+            llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+
+            n_outputs = ubatch.n_tokens;
+
+            // TODO: not sure if this is needed
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                LM_GGML_ABORT("TODO: handle this error");
+            }
+
+            auto * gf = graph_init();
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+            struct lm_ggml_context * ctx_compute_opt;
+            {
+                const size_t size_gf = lm_ggml_graph_size(gf);
+                const size_t size_meta = 4*size_gf*lm_ggml_tensor_overhead() + 2*lm_ggml_graph_overhead_custom(size_gf, /*grads = */ true);
+                struct lm_ggml_init_params params = {
+                    /*.mem_size   =*/ size_meta,
+                    /*.mem_buffer =*/ nullptr,
+                    /*.no_alloc   =*/ true,
+                };
+                ctx_compute_opt = lm_ggml_init(params);
+            }
+            lm_ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+            lm_ggml_opt_alloc(opt_ctx, train);
+            res->set_inputs(&ubatch);
+            {
+                struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx);
+                LM_GGML_ASSERT(labels->ne[1] == n_ubatch);
+                lm_ggml_set_zero(labels);
+                const float onef = 1.0f;
+                for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
+                    const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
+                    LM_GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
+                    lm_ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
+                }
+            }
+            lm_ggml_opt_eval(opt_ctx, result);
+            if (callback) {
+                callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
+            }
+            lm_ggml_free(ctx_compute_opt);
+        }
+    }
+
+    kv_guard.commit();
+}
+
+void llama_context::opt_epoch(
+        lm_ggml_opt_dataset_t      dataset,
+        lm_ggml_opt_result_t       result_train,
+        lm_ggml_opt_result_t       result_eval,
+        int64_t                    idata_split,
+        lm_ggml_opt_epoch_callback callback_train,
+        lm_ggml_opt_epoch_callback callback_eval) {
+    const uint32_t n_ctx    = this->n_ctx();
+    const uint32_t n_batch  = std::min(cparams.n_batch,  n_ctx);
+    const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
+    const int64_t  ndata    = lm_ggml_opt_dataset_ndata(dataset);
+
+    LM_GGML_ASSERT(idata_split >= 0);
+    LM_GGML_ASSERT(idata_split <= ndata);
+
+    const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
+
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    std::vector<llama_token> tokens(n_ctx);
+    std::vector<llama_token> labels_sparse(n_ctx);
+
+    int64_t idata = 0;
+
+    int64_t t_loop_start  = lm_ggml_time_us();
+    int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
+    for (; idata < idata_split; ++idata) {
+        constexpr bool train = true;
+        const int64_t idata_in_loop = idata*ubatch_per_ctx;
+
+        lm_ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
+            callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    t_loop_start  = lm_ggml_time_us();
+    ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
+    for (; idata < ndata; ++idata) {
+        constexpr bool train = false;
+        const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
+
+        lm_ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
+            callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    llama_batch_free(batch);
+}
+
 //
 // interface implementation
 //
@@ -1853,13 +2092,14 @@ llama_context_params llama_context_default_params() {
         /*.cb_eval_user_data     =*/ nullptr,
         /*.type_k                =*/ LM_GGML_TYPE_F16,
         /*.type_v                =*/ LM_GGML_TYPE_F16,
-        /*.logits_all            =*/ false,
+        /*.abort_callback        =*/ nullptr,
+        /*.abort_callback_data   =*/ nullptr,
         /*.embeddings            =*/ false,
         /*.offload_kqv           =*/ true,
         /*.flash_attn            =*/ false,
         /*.no_perf               =*/ true,
-        /*.abort_callback        =*/ nullptr,
-        /*.abort_callback_data   =*/ nullptr,
+        /*.op_offload            =*/ true,
+        /*.swa_full              =*/ true,
     };
 
     return result;
@@ -2054,65 +2294,51 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
-//
-// kv cache view
-//
-
-llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return {};
-    }
-
-    return llama_kv_cache_view_init(*kv, n_seq_max);
-}
-
-void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return;
-    }
-
-    llama_kv_cache_view_update(view, kv);
-}
-
 //
 // kv cache
 //
 
 // deprecated
-int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
-    return llama_kv_self_n_tokens(ctx);
-}
-
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_n_tokens();
-}
+    int32_t res = 0;
 
-// deprecated
-int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
-    return llama_kv_self_used_cells(ctx);
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
+// deprecated
+// note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_used_cells();
-}
+    int32_t res = 0;
 
-// deprecated
-void llama_kv_cache_clear(llama_context * ctx) {
-    llama_kv_self_clear(ctx);
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
@@ -2124,15 +2350,6 @@ void llama_kv_self_clear(llama_context * ctx) {
     kv->clear();
 }
 
-// deprecated
-bool llama_kv_cache_seq_rm(
-        llama_context * ctx,
-         llama_seq_id   seq_id,
-            llama_pos   p0,
-            llama_pos   p1) {
-    return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
-}
-
 bool llama_kv_self_seq_rm(
         llama_context * ctx,
          llama_seq_id   seq_id,
@@ -2146,16 +2363,6 @@ bool llama_kv_self_seq_rm(
     return kv->seq_rm(seq_id, p0, p1);
 }
 
-// deprecated
-void llama_kv_cache_seq_cp(
-        llama_context * ctx,
-         llama_seq_id   seq_id_src,
-         llama_seq_id   seq_id_dst,
-            llama_pos   p0,
-            llama_pos   p1) {
-    llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
-}
-
 void llama_kv_self_seq_cp(
         llama_context * ctx,
          llama_seq_id   seq_id_src,
@@ -2170,13 +2377,6 @@ void llama_kv_self_seq_cp(
     kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
-// deprecated
-void llama_kv_cache_seq_keep(
-        llama_context * ctx,
-         llama_seq_id   seq_id) {
-    llama_kv_self_seq_keep(ctx, seq_id);
-}
-
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2186,16 +2386,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
     kv->seq_keep(seq_id);
 }
 
-// deprecated
-void llama_kv_cache_seq_add(
-        llama_context * ctx,
-         llama_seq_id   seq_id,
-            llama_pos   p0,
-            llama_pos   p1,
-            llama_pos   delta) {
-    llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
-}
-
 void llama_kv_self_seq_add(
         llama_context * ctx,
         llama_seq_id seq_id,
@@ -2210,16 +2400,6 @@ void llama_kv_self_seq_add(
     kv->seq_add(seq_id, p0, p1, delta);
 }
 
-// deprecated
-void llama_kv_cache_seq_div(
-        llama_context * ctx,
-         llama_seq_id   seq_id,
-            llama_pos   p0,
-            llama_pos   p1,
-                  int   d) {
-    llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
-}
-
 void llama_kv_self_seq_div(
         llama_context * ctx,
         llama_seq_id seq_id,
@@ -2234,25 +2414,24 @@ void llama_kv_self_seq_div(
     kv->seq_div(seq_id, p0, p1, d);
 }
 
-// deprecated
-llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_self_seq_pos_max(ctx, seq_id);
+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
-        return 0;
+        return -1;
     }
 
     return kv->seq_pos_max(seq_id);
 }
 
-// deprecated
-void llama_kv_cache_defrag(llama_context * ctx) {
-    llama_kv_self_defrag(ctx);
-}
-
 void llama_kv_self_defrag(llama_context * ctx) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2263,11 +2442,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
     kv->defrag_sched(-1.0f);
 }
 
-// deprecated
-bool llama_kv_cache_can_shift(const llama_context * ctx) {
-    return llama_kv_self_can_shift(ctx);
-}
-
 bool llama_kv_self_can_shift(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2277,11 +2451,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
    return kv->get_can_shift();
 }
 
-// deprecated
-void llama_kv_cache_update(llama_context * ctx) {
-    llama_kv_self_update(ctx);
-}
-
 // llama state API
 
 // deprecated
@@ -2404,7 +2573,21 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    const int ret = ctx->decode(batch);
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+            return ret;
+        }
+    }
+
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -2444,3 +2627,34 @@ void llama_perf_context_print(const llama_context * ctx) {
 void llama_perf_context_reset(llama_context * ctx) {
     ctx->perf_reset();
 }
+
+//
+// training
+//
+
+bool llama_opt_param_filter_all(const struct lm_ggml_tensor * tensor, void * userdata) {
+    LM_GGML_UNUSED(tensor);
+    LM_GGML_UNUSED(userdata);
+    return true;
+}
+
+void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
+    ctx->opt_init(model, lopt_params);
+}
+
+void llama_opt_epoch(
+        struct llama_context     * ctx,
+        lm_ggml_opt_dataset_t      dataset,
+        lm_ggml_opt_result_t       result_train,
+        lm_ggml_opt_result_t       result_eval,
+        int64_t                    idata_split,
+        lm_ggml_opt_epoch_callback callback_train,
+        lm_ggml_opt_epoch_callback callback_eval) {
+    ctx->opt_epoch(
+        dataset,
+        result_train,
+        result_eval,
+        idata_split,
+        callback_train,
+        callback_eval);
+}
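
Note: the `llama_opt_param_filter_all`, `llama_opt_init`, and `llama_opt_epoch` entry points added at the end of this file could be driven roughly as below. This is a minimal sketch, not code from the package: the `llama_opt_params` field order, and the `lm_ggml_opt_result_*` / `lm_ggml_opt_get_default_optimizer_params` helpers, are assumed to mirror their upstream ggml-opt counterparts (this package ships `ggml-opt.h`, changed in this release), and dataset construction is left out.

// Sketch: two training epochs with the new llama_opt_* API (hypothetical usage).
#include "llama.h"
#include "ggml-opt.h"

static void finetune(llama_context * ctx, llama_model * model, lm_ggml_opt_dataset_t dataset) {
    struct llama_opt_params lopt_params = {
        /*.n_ctx_train     =*/ 0,                          // 0: fall back to the context's n_ctx
        /*.param_filter    =*/ llama_opt_param_filter_all, // train every F32 tensor
        /*.param_filter_ud =*/ nullptr,
        /*.get_opt_pars    =*/ lm_ggml_opt_get_default_optimizer_params, // assumed upstream helper
        /*.get_opt_pars_ud =*/ nullptr,
    };
    llama_opt_init(ctx, model, lopt_params);

    // train on the first 90% of the dataset, evaluate on the remaining 10%
    const int64_t idata_split = lm_ggml_opt_dataset_ndata(dataset)*9/10;

    lm_ggml_opt_result_t result_train = lm_ggml_opt_result_init(); // assumed upstream helper
    lm_ggml_opt_result_t result_eval  = lm_ggml_opt_result_init();

    for (int epoch = 0; epoch < 2; ++epoch) {
        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                        /*callback_train =*/ nullptr,
                        /*callback_eval  =*/ nullptr);
        lm_ggml_opt_result_reset(result_train);
        lm_ggml_opt_result_reset(result_eval);
    }

    lm_ggml_opt_result_free(result_train);
    lm_ggml_opt_result_free(result_eval);
}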