@fugood/llama.node 0.3.14 → 0.3.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +37 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +20 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  58. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  60. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  61. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  62. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  63. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  64. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
  65. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  66. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
  67. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  68. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  69. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  70. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  71. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  72. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
  73. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  74. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  75. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  76. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  78. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  79. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  82. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
  83. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  84. package/src/llama.cpp/include/llama.h +86 -22
  85. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  86. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  87. package/src/llama.cpp/src/llama-adapter.h +11 -9
  88. package/src/llama.cpp/src/llama-arch.cpp +103 -16
  89. package/src/llama.cpp/src/llama-arch.h +18 -0
  90. package/src/llama.cpp/src/llama-batch.h +2 -2
  91. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  92. package/src/llama.cpp/src/llama-context.h +214 -77
  93. package/src/llama.cpp/src/llama-cparams.h +1 -0
  94. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  95. package/src/llama.cpp/src/llama-graph.h +574 -0
  96. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  97. package/src/llama.cpp/src/llama-hparams.h +9 -0
  98. package/src/llama.cpp/src/llama-io.cpp +15 -0
  99. package/src/llama.cpp/src/llama-io.h +35 -0
  100. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  101. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  102. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  103. package/src/llama.cpp/src/llama-memory.h +21 -0
  104. package/src/llama.cpp/src/llama-model.cpp +8244 -173
  105. package/src/llama.cpp/src/llama-model.h +34 -1
  106. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  107. package/src/llama.cpp/src/llama.cpp +51 -9984
  108. package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
  109. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  110. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
package/src/llama.cpp/src/llama-context.cpp
@@ -1,732 +1,846 @@
  #include "llama-context.h"

  #include "llama-impl.h"
+ #include "llama-io.h"
  #include "llama-mmap.h"
+ #include "llama-model.h"
+ #include "llama-kv-cache.h"

  #include <cassert>
- #include <cmath>
  #include <cstring>
  #include <stdexcept>
+ #include <cinttypes>

- void llama_set_k_shift(struct llama_context & lctx) {
- const int64_t kv_size = lctx.kv_self.size;
+ //
+ // llama_context
+ //

- assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+ llama_context::llama_context(
+ const llama_model & model,
+ llama_context_params params) :
+ model(model) {
+ LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

- int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+ t_start_us = model.t_start_us;
+ t_load_us = model.t_load_us;

- for (int i = 0; i < kv_size; ++i) {
- data[i] = lctx.kv_self.cells[i].delta;
- }
- }
+ const auto & hparams = model.hparams;

- void llama_set_s_copy(struct llama_context & lctx) {
- const int64_t kv_size = lctx.kv_self.size;
+ cparams.n_seq_max = std::max(1u, params.n_seq_max);
+ cparams.n_threads = params.n_threads;
+ cparams.n_threads_batch = params.n_threads_batch;
+ cparams.yarn_ext_factor = params.yarn_ext_factor;
+ cparams.yarn_attn_factor = params.yarn_attn_factor;
+ cparams.yarn_beta_fast = params.yarn_beta_fast;
+ cparams.yarn_beta_slow = params.yarn_beta_slow;
+ cparams.defrag_thold = params.defrag_thold;
+ cparams.embeddings = params.embeddings;
+ cparams.offload_kqv = params.offload_kqv;
+ cparams.flash_attn = params.flash_attn;
+ cparams.no_perf = params.no_perf;
+ cparams.pooling_type = params.pooling_type;
+ cparams.warmup = false;

- assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+ cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
+ cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
+ cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

- int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+ cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
+ hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
+ hparams.n_ctx_train;

- for (int i = 0; i < kv_size; ++i) {
- data[i] = lctx.kv_self.cells[i].src;
- }
- }
+ cparams.cb_eval = params.cb_eval;
+ cparams.cb_eval_user_data = params.cb_eval_user_data;

- // llama input
+ auto rope_scaling_type = params.rope_scaling_type;
+ if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
+ rope_scaling_type = hparams.rope_scaling_type_train;
+ }

- static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
- // TODO move to hparams if a T5 variant appears that uses a different value
- const int64_t max_distance = 128;
+ if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
+ cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
+ }

- if (bidirectional) {
- n_buckets >>= 1;
+ if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
+ cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
  }

- const int64_t max_exact = n_buckets >> 1;
+ cparams.yarn_attn_factor *= hparams.rope_attn_factor;

- int32_t relative_position = x - y;
- int32_t relative_bucket = 0;
- if (bidirectional) {
- relative_bucket += (relative_position > 0) * n_buckets;
- relative_position = abs(relative_position);
+ if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+ if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+ cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+ } else {
+ cparams.pooling_type = hparams.pooling_type;
+ }
+ }
+
+ if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+ cparams.causal_attn = hparams.causal_attn;
  } else {
- relative_position = -std::min<int32_t>(relative_position, 0);
+ cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
  }
- int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
- relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
- relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
- return relative_bucket;
- }

- void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
- //
- // set input data
- //
+ // with causal attention, the batch size is limited by the context size
+ cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

- const auto & hparams = lctx.model.hparams;
- const auto & cparams = lctx.cparams;
- const auto & kv_self = lctx.kv_self;
+ // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+ // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+ // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+ // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+ if (cparams.n_batch < GGML_KQ_MASK_PAD) {
+ LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
+ cparams.n_batch = GGML_KQ_MASK_PAD;
+ }

- if (ubatch.token) {
- const int64_t n_tokens = ubatch.n_tokens;
+ cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

- ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
- }
+ const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

- if (ubatch.embd) {
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_tokens = ubatch.n_tokens;
+ LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
+ LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+ LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
+ LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
+ LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
+ LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
+ LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+ LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
+ LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

- ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+ if (n_ctx_per_seq < hparams.n_ctx_train) {
+ LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+ __func__, n_ctx_per_seq, hparams.n_ctx_train);
  }

- if (ubatch.pos && lctx.inp_pos) {
- const int64_t n_tokens = ubatch.n_tokens;
- auto n_pos = lctx.n_pos_per_token;
- ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
+ if (n_ctx_per_seq > hparams.n_ctx_train) {
+ LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+ __func__, n_ctx_per_seq, hparams.n_ctx_train);
  }

- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
- //GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
-
- if (!lctx.inp_out_ids) {
- LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__);
- } else {
- const int64_t n_tokens = ubatch.n_tokens;
+ logits_all = params.logits_all;

- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
- int32_t * data = (int32_t *) lctx.inp_out_ids->data;
+ if (!hparams.vocab_only) {
+ // GPU backends
+ for (auto * dev : model.devices) {
+ ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+ if (backend == nullptr) {
+ throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
+ }
+ backends.emplace_back(backend);
+ }

- if (lctx.n_outputs == n_tokens) {
- for (int i = 0; i < n_tokens; ++i) {
- data[i] = i;
+ // add ACCEL backends (such as BLAS)
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+ ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+ if (backend == nullptr) {
+ throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
  }
- } else if (ubatch.output) {
- int32_t n_outputs = 0;
- for (int i = 0; i < n_tokens; ++i) {
- if (ubatch.output[i]) {
- data[n_outputs++] = i;
- }
+ backends.emplace_back(backend);
+ }
+ }
+
+ // add CPU backend
+ backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+ if (backend_cpu == nullptr) {
+ throw std::runtime_error("failed to initialize CPU backend");
+ }
+ backends.emplace_back(backend_cpu);
+
+ // create a list of the set_n_threads functions in the backends
+ for (auto & backend : backends) {
+ ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+ if (reg) {
+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+ if (ggml_backend_set_n_threads_fn) {
+ set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
  }
- // the graph needs to have been passed the correct number of outputs
- GGML_ASSERT(lctx.n_outputs == n_outputs);
- } else if (lctx.n_outputs == 1) {
- // only keep last output
- data[0] = n_tokens - 1;
- } else {
- GGML_ASSERT(lctx.n_outputs == 0);
  }
  }
- }

- GGML_ASSERT(
- // (!a || b) is a logical implication (a -> b)
- // !hparams.causal_attn -> !cparams.causal_attn
- (hparams.causal_attn || !cparams.causal_attn) &&
- "causal attention is not supported by this model"
- );
+ llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);

- if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
- // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
- if (cparams.causal_attn && !lctx.is_encoding) {
- const int64_t n_kv = kv_self.n;
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
+ // graph outputs buffer
+ {
+ // resized during inference when a batch uses more outputs
+ if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
+ throw std::runtime_error("failed to reserve initial output buffer");
+ }

+ LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
+ ggml_backend_buffer_name (buf_output.get()),
+ ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
+ }
+ }

- float * data = nullptr;
- float * data_swa = nullptr;
+ // init the memory module
+ // TODO: for now, always create a unified KV cache
+ if (!hparams.vocab_only) {
+ kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));

- if (lctx.inp_KQ_mask) {
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
- data = (float *) lctx.inp_KQ_mask->data;
- }
+ LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

- if (lctx.inp_KQ_mask_swa) {
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
- data_swa = (float *) lctx.inp_KQ_mask_swa->data;
- }
+ cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));

- // For causal attention, use only the previous KV cells
- // of the correct sequence for each token of the ubatch.
- // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
- for (int h = 0; h < 1; ++h) {
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
- for (int j = 0; j < n_seq_tokens; ++j) {
- const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
-
- for (int i = 0; i < n_kv; ++i) {
- float f;
- if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
- f = -INFINITY;
- } else {
- if (hparams.use_alibi) {
- f = -std::abs(kv_self.cells[i].pos - pos);
- } else {
- f = 0.0f;
- }
- }
+ LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

- if (data) {
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
- }
+ uint32_t kv_size = cparams.n_ctx;
+ ggml_type type_k = params.type_k;
+ ggml_type type_v = params.type_v;

- // may need to cut off old tokens for sliding window
- if (data_swa) {
- if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
- f = -INFINITY;
- }
- data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
- }
- }
- }
- }
+ if (llama_model_is_recurrent(&model)) {
+ // Mamba needs at least as many KV cells as there are sequences kept at any time
+ kv_size = std::max((uint32_t) 1, params.n_seq_max);
+ // it's probably best to keep as much precision as possible for the states
+ type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+ type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+ }

- if (data) {
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
- for (int j = 0; j < n_kv; ++j) {
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
- }
- }
- }
+ GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
+ GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);

- if (data_swa) {
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
- for (int j = 0; j < n_kv; ++j) {
- data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
- }
- }
- }
- }
- } else {
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
- // when using kv cache, the mask needs to match the kv cache size
- const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
-
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-
- float * data = (float *) lctx.inp_KQ_mask->data;
-
- for (int h = 0; h < 1; ++h) {
- for (int s1 = 0; s1 < n_seqs; ++s1) {
- const llama_seq_id seq_id = ubatch.seq_id[s1][0];
-
- for (int j = 0; j < n_seq_tokens; ++j) {
- const int32_t tj = s1*n_seq_tokens + j;
-
- for (int s0 = 0; s0 < n_seqs; ++s0) {
- for (int i = 0; i < n_seq_tokens; ++i) {
- const int32_t ti = s0*n_seq_tokens + i;
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
- if (ubatch.seq_id[s0][s] == seq_id) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
- } else {
- f = 0.0f;
- }
- break;
- }
- }
-
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
- }
- }
+ if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
+ throw std::runtime_error("failed to initialize self-attention cache");
+ }

- for (int i = n_tokens; i < n_stride; ++i) {
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
- }
- }
- }
- }
+ {
+ const size_t memory_size_k = kv_self->size_k_bytes();
+ const size_t memory_size_v = kv_self->size_v_bytes();
+
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+ ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+ ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  }
  }

- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
+ // init backends
+ if (!hparams.vocab_only) {
+ LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__);

- GGML_ASSERT(lctx.inp_mean);
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
+ backend_buft.clear();
+ backend_ptrs.clear();

- float * data = (float *) lctx.inp_mean->data;
- memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
+ for (auto & backend : backends) {
+ auto * buft = ggml_backend_get_default_buffer_type(backend.get());
+ auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));

- std::vector<uint64_t> sum(n_tokens, 0);
+ if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
+ // use the host buffer of the first device CPU for faster transfer of the intermediate state
+ auto * dev = model.devices[0];
+ auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+ if (host_buft) {
+ buft = host_buft;
+ }
+ }

- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ backend_buft.push_back(buft);
+ backend_ptrs.push_back(backend.get());
+ }

- // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
+ LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());

- sum[seq_id] += ubatch.n_seq_tokens;
- }
+ const size_t max_nodes = this->graph_max_nodes();

- std::vector<float> div(n_tokens, 0.0f);
- for (int i = 0; i < n_tokens; ++i) {
- const uint64_t s = sum[i];
- if (s > 0) {
- div[i] = 1.0f/float(s);
+ LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
+
+ // buffer used to store the computation graph and the tensor meta data
+ buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+ // TODO: move these checks to ggml_backend_sched
+ // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
+ bool pipeline_parallel =
+ model.n_devices() > 1 &&
+ model.params.n_gpu_layers > (int) model.hparams.n_layer &&
+ model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
+ cparams.offload_kqv;
+
+ // pipeline parallelism requires support for async compute and events in all devices
+ if (pipeline_parallel) {
+ for (auto & backend : backends) {
+ auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+ if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
+ // ignore CPU backend
+ continue;
+ }
+ auto * dev = ggml_backend_get_device(backend.get());
+ ggml_backend_dev_props props;
+ ggml_backend_dev_get_props(dev, &props);
+ if (!props.caps.async || !props.caps.events) {
+ // device does not support async compute or events
+ pipeline_parallel = false;
+ break;
+ }
  }
  }

- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));

- for (int i = 0; i < n_seq_tokens; ++i) {
- data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
- }
+ if (pipeline_parallel) {
+ LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
  }
  }

- if (cparams.embeddings && (
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
+ // reserve worst-case graph
+ if (!hparams.vocab_only) {
+ const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

- GGML_ASSERT(lctx.inp_cls);
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+ llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph

- uint32_t * data = (uint32_t *) lctx.inp_cls->data;
- memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+ // restore later
+ // TODO: something cleaner
+ const auto n_outputs_save = n_outputs;

- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ // max number of outputs
+ n_outputs = n_tokens;

- // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);

- for (int i = 0; i < n_seq_tokens; ++i) {
- const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
+ int n_splits_pp = -1;
+ int n_nodes_pp = -1;

- if (pos == 0) {
- data[seq_id] = s*n_seq_tokens + i;
- }
- }
- }
- }
+ int n_splits_tg = -1;
+ int n_nodes_tg = -1;

- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
+ // simulate full KV cache
+ kv_self->n = kv_self->size;

- GGML_ASSERT(lctx.inp_cls);
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+ cross.v_embd.clear();

- uint32_t * data = (uint32_t *) lctx.inp_cls->data;
- memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+ // reserve pp graph first so that buffers are only allocated once
+ {
+ llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+ auto * gf = graph_init();
+ graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+ if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+ throw std::runtime_error("failed to allocate compute pp buffers");
+ }

- std::vector<int> last_pos(n_tokens, -1);
- std::vector<int> last_row(n_tokens, -1);
+ n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
+ n_nodes_pp = ggml_graph_n_nodes(gf);
+ }

- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ // reserve with tg graph to get the number of splits and nodes
+ {
+ llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+ auto * gf = graph_init();
+ graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+ if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+ throw std::runtime_error("failed to allocate compute tg buffers");
+ }
+ n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
+ n_nodes_tg = ggml_graph_n_nodes(gf);
+ }

- // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+ // reserve again with pp graph to avoid ggml-alloc reallocations during inference
+ {
+ llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+ auto * gf = graph_init();
+ graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+ if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+ throw std::runtime_error("failed to allocate compute pp buffers");
+ }
+ }

- for (int i = 0; i < n_seq_tokens; ++i) {
- const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
+ n_outputs = n_outputs_save;

- if (pos >= last_pos[seq_id]) {
- last_pos[seq_id] = pos;
- last_row[seq_id] = s*n_seq_tokens + i;
- }
+ for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+ ggml_backend_t backend = backend_ptrs[i];
+ ggml_backend_buffer_type_t buft = backend_buft[i];
+ size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+ if (size > 1) {
+ LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+ ggml_backend_buft_name(buft),
+ size / 1024.0 / 1024.0);
  }
  }

- for (int i = 0; i < n_tokens; ++i) {
- if (last_row[i] >= 0) {
- data[i] = last_row[i];
- }
+ if (n_nodes_pp == n_nodes_tg) {
+ LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
+ } else {
+ LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
  }
- }

- if (kv_self.recurrent) {
- const int64_t n_kv = kv_self.n;
+ if (n_splits_pp == n_splits_tg) {
+ LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
+ } else {
+ LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
+ }
+ }
+ }

- if (lctx.inp_s_mask) {
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
- float * data = (float *) lctx.inp_s_mask->data;
+ llama_context::~llama_context() = default;

- // clear unused states
- for (int i = 0; i < n_kv; ++i) {
- const uint32_t cell_id = i + kv_self.head;
- llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
+ void llama_context::synchronize() {
+ ggml_backend_sched_synchronize(sched.get());

- data[i] = (float) (kv_cell.src >= 0);
+ // FIXME: if multiple single tokens are evaluated without a synchronization,
+ // the stats will be added to the prompt evaluation stats
+ // this should only happen when using batch size 1 to evaluate a batch

- // only clear once
- if (kv_cell.src < 0) {
- kv_cell.src = cell_id;
- }
- }
+ // add the evaluation to the stats
+ if (n_queued_tokens == 1) {
+ if (!cparams.no_perf) {
+ t_eval_us += ggml_time_us() - t_compute_start_us;
  }
+ n_eval++;
+ } else if (n_queued_tokens > 1) {
+ if (!cparams.no_perf) {
+ t_p_eval_us += ggml_time_us() - t_compute_start_us;
+ }
+ n_p_eval += n_queued_tokens;
+ }

- if (lctx.inp_s_copy) {
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
- int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+ // get a more accurate load time, upon first eval
+ if (n_queued_tokens > 0 && !has_evaluated_once) {
+ t_load_us = ggml_time_us() - t_start_us;
+ has_evaluated_once = true;
+ }

- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
- for (uint32_t i = 0; i < n_kv; ++i) {
- const uint32_t cell_id = i + kv_self.head;
- llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
+ n_queued_tokens = 0;
+ t_compute_start_us = 0;
+ }

- // prevent out-of-bound sources
- if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
- kv_cell.src = cell_id;
- }
+ const llama_model & llama_context::get_model() const {
+ return model;
+ }

- data[i] = kv_cell.src;
+ uint32_t llama_context::n_ctx() const {
+ return cparams.n_ctx;
+ }

- // ensure copy only happens once
- if (kv_cell.src != (int32_t) cell_id) {
- kv_cell.src = cell_id;
- }
- }
- }
- }
+ uint32_t llama_context::n_ctx_per_seq() const {
+ return cparams.n_ctx / cparams.n_seq_max;
+ }

- if (lctx.inp_pos_bucket) {
- const int64_t n_tokens = ubatch.n_tokens;
+ uint32_t llama_context::n_batch() const {
+ return cparams.n_batch;
+ }

- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
- GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
+ uint32_t llama_context::n_ubatch() const {
+ return cparams.n_ubatch;
+ }

- int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
+ uint32_t llama_context::n_seq_max() const {
+ return cparams.n_seq_max;
+ }

- if (!lctx.is_encoding) {
- const int64_t n_kv = kv_self.n;
- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- for (int i = 0; i < n_kv; ++i) {
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
- }
- }
- }
- } else {
- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- for (int i = 0; i < n_tokens; ++i) {
- data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
- }
+ uint32_t llama_context::n_threads() const {
+ return cparams.n_threads;
+ }
+
+ uint32_t llama_context::n_threads_batch() const {
+ return cparams.n_threads_batch;
+ }
+
+ llama_kv_cache * llama_context::get_kv_self() {
+ return kv_self.get();
+ }
+
+ const llama_kv_cache * llama_context::get_kv_self() const {
+ return kv_self.get();
+ }
+
+ ggml_tensor * llama_context::build_rope_shift(
+ ggml_context * ctx0,
+ ggml_tensor * cur,
+ ggml_tensor * shift,
+ ggml_tensor * factors,
+ float freq_base,
+ float freq_scale,
+ ggml_backend_buffer * bbuf) const {
+ const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
+
+ const auto & yarn_ext_factor = cparams.yarn_ext_factor;
+ const auto & yarn_attn_factor = cparams.yarn_attn_factor;
+ const auto & yarn_beta_fast = cparams.yarn_beta_fast;
+ const auto & yarn_beta_slow = cparams.yarn_beta_slow;
+
+ const auto & hparams = model.hparams;
+
+ const auto & n_rot = hparams.n_rot;
+ const auto & rope_type = hparams.rope_type;
+
+ ggml_tensor * tmp;
+
+ if (ggml_is_quantized(cur->type)) {
+ // dequantize to f32 -> RoPE -> quantize back
+ tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
+
+ if (bbuf) {
+ for (const auto & backend : backends) {
+ // Figure out which backend KV cache belongs to
+ if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
+ ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
+ break;
  }
  }
  }
- }

- if (!lctx.is_encoding && lctx.inp_embd_enc) {
- assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
- assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
+ tmp = ggml_rope_ext_inplace(ctx0, tmp,
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);

- ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
+ tmp = ggml_cpy(ctx0, tmp, cur);
+ } else {
+ // we rotate only the first n_rot dimensions
+ tmp = ggml_rope_ext_inplace(ctx0, cur,
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
  }

- if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
- const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
- const int64_t n_tokens = ubatch.n_tokens;
+ return tmp;
+ }

- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
- GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
+ class llm_graph_input_k_shift : public llm_graph_input_i {
+ public:
+ llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+ virtual ~llm_graph_input_k_shift() = default;

- float * data = (float *) lctx.inp_KQ_mask_cross->data;
+ void set_input(const llama_ubatch * ubatch) override;

- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- for (int i = 0; i < n_output_enc; ++i) {
- float f = -INFINITY;
- for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[j][s];
- if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
- f = 0.0f;
- }
- }
- data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
- }
- }
+ ggml_tensor * k_shift; // I32 [kv_size]

- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
- for (int j = 0; j < n_output_enc; ++j) {
- data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
- }
- }
+ const llama_kv_cache_unified * kv_self;
+ };
+
+ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
+ GGML_UNUSED(ubatch);
+
+ if (k_shift) {
+ assert(ggml_backend_buffer_is_host(k_shift->buffer));
+
+ int32_t * data = (int32_t *) k_shift->data;
+
+ for (uint32_t i = 0; i < kv_self->size; ++i) {
+ data[i] = kv_self->cells[i].delta;
  }
  }
  }

- // llama output
+ llm_graph_result_ptr llama_context::build_kv_self_shift(
+ ggml_context * ctx0,
+ ggml_cgraph * gf) const {
+ auto res = std::make_unique<llm_graph_result>();

- size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
- const auto & cparams = lctx.cparams;
- const auto & hparams = lctx.model.hparams;
- const auto & vocab = lctx.model.vocab;
+ const auto & hparams = model.hparams;

- const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+ const auto & n_layer = hparams.n_layer;

- const auto n_batch = cparams.n_batch;
- const auto n_vocab = vocab.n_tokens();
- const auto n_embd = hparams.n_embd;
+ const auto & n_embd_head_k = hparams.n_embd_head_k;
+ //const auto & n_embd_head_v = hparams.n_embd_head_v;

- // TODO: use a per-batch flag for logits presence instead
- const bool has_logits = !cparams.embeddings;
- const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+ //GGML_ASSERT(kv_self->size == n_ctx);

- const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
- const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
+ auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());

- if (lctx.output_ids.empty()) {
- // init, never resized afterwards
- lctx.output_ids.resize(n_batch);
- }
+ inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
+ ggml_set_input(inp->k_shift);

- const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
- const size_t new_size = (logits_size + embd_size) * sizeof(float);
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ const int64_t n_head_kv = hparams.n_head_kv(il);
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

- // alloc only when more than the current capacity is required
- // TODO: also consider shrinking the buffer
- if (!lctx.buf_output || prev_size < new_size) {
- if (lctx.buf_output) {
- #ifndef NDEBUG
- // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
- LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
- #endif
- lctx.buf_output = nullptr;
- lctx.logits = nullptr;
- lctx.embd = nullptr;
- }
+ const bool is_swa = hparams.is_swa(il);

- auto * buft = ggml_backend_cpu_buffer_type();
- // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
- auto * output_dev = lctx.model.dev_output();
- auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
- if (output_dev_host_buft) {
- buft = output_dev_host_buft;
- }
- lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
- if (lctx.buf_output == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
- return 0;
- }
+ // note: the swa rope params could become part of the cparams in the future
+ // if we decide to make them configurable, like the non-sliding ones
+ const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
+ const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
+
+ ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
+
+ ggml_tensor * k =
+ ggml_view_3d(ctx0, kv_self->k_l[il],
+ n_embd_head_k, n_head_kv, kv_self->size,
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+ 0);
+
+ ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
+
+ ggml_build_forward_expand(gf, cur);
  }

- float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get());
+ res->add_input(std::move(inp));

- lctx.logits = has_logits ? output_base : nullptr;
- lctx.embd = has_embd ? output_base + logits_size : nullptr;
+ return res;
+ }

+ llm_graph_result_ptr llama_context::build_kv_self_defrag(
+ ggml_context * ctx0,
+ ggml_cgraph * gf) const {
+ auto res = std::make_unique<llm_graph_result>();

- lctx.output_size = n_outputs_max;
- lctx.logits_size = logits_size;
- lctx.embd_size = embd_size;
+ const auto & hparams = model.hparams;

- // set all ids as invalid (negative)
- std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
+ const auto & ids = kv_self->defrag_info.ids;

- ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
+ #if 0
+ // CPU defrag
+ //
+ // TODO: optimizations are possible:
+ // - multiple threads
+ // - avoid copying to the host memory when already there
+ //
+ // likely not worth the effort, as we have ggml_graph based defrag
+ //

- lctx.n_outputs = 0;
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();

- return n_outputs_max;
- }
+ const uint32_t kv_size = size;

- void llama_output_reorder(struct llama_context & ctx) {
- std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
- if (!out_ids.empty()) {
- const uint32_t n_vocab = ctx.model.vocab.n_tokens();
- const uint32_t n_embd = ctx.model.hparams.n_embd;
+ std::vector<uint8_t> buf_k;
+ std::vector<uint8_t> buf_v;

- const int32_t n_outputs = ctx.n_outputs;
- GGML_ASSERT((size_t) n_outputs == out_ids.size());
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+ const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
+
+ const size_t v_size_el = ggml_type_size(v_l[il]->type);
+ const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
+
+ buf_k.resize(k_size);
+ buf_v.resize(v_size);
+
+ ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
+ ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
+
+ // batch move [i, i+nm) to [id, id+nm)
+ // note: cells can move only to a lower index
+ for (uint32_t i = 0; i < n_kv; ++i) {
+ const uint32_t id = ids[i];
+
+ if (i == id || id == n_kv) {
+ continue;
  }
- if (j_min == i) { continue; }
- std::swap(out_ids[i], out_ids[j_min]);
- if (ctx.logits_size > 0) {
- for (uint32_t k = 0; k < n_vocab; k++) {
- std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]);
- }
+
+ uint32_t nm = 1;
+
+ while (i + nm < n_kv && ids[i + nm] == id + nm) {
+ nm++;
  }
- if (ctx.embd_size > 0) {
- for (uint32_t k = 0; k < n_embd; k++) {
- std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]);
+
+ // move keys
+ {
+ const int64_t os = i*k_size_row;
+ const int64_t od = id*k_size_row;
+
+ memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+ }
+
+ // move values (note: they are transposed)
+ {
+ const int64_t os = i;
+ const int64_t od = id;
+
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+ memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
  }
  }
+
+ i += nm - 1;
  }
- std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1);
- for (int32_t i = 0; i < n_outputs; ++i) {
- ctx.output_ids[out_ids[i]] = i;
- }
- out_ids.clear();
- }
- }

- //
- // interface implementation
- //
+ ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
+ ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
+ }
+ #else
+ for (uint32_t i = 0; i < ids.size(); ++i) {
+ const uint32_t id = ids[i];

- void llama_free(struct llama_context * ctx) {
- delete ctx;
- }
+ if (i == id || id == ids.size()) {
+ continue;
+ }

- uint32_t llama_n_ctx(const struct llama_context * ctx) {
- return ctx->cparams.n_ctx;
- }
+ uint32_t nm = 1;

- uint32_t llama_n_batch(const struct llama_context * ctx) {
- return ctx->cparams.n_batch;
- }
+ while (i + nm < ids.size() && ids[i + nm] == id + nm) {
+ nm++;
+ }

- uint32_t llama_n_ubatch(const struct llama_context * ctx) {
- return ctx->cparams.n_ubatch;
- }
+ for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+ ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
+ n_embd_k_gqa, nm,
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
+
+ ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
+ n_embd_k_gqa, nm,
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
+
+ ggml_tensor * view_v_src;
+ ggml_tensor * view_v_dst;
+
+ if (cparams.flash_attn) {
+ // NOTE: the V cache is not transposed when using flash attention
+ view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+ n_embd_v_gqa, nm,
+ ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+ ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
+
+ view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+ n_embd_v_gqa, nm,
+ ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+ ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
+ } else {
+ view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+ nm, n_embd_v_gqa,
+ ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+ ggml_row_size(kv_self->v_l[il]->type, i));
+
+ view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+ nm, n_embd_v_gqa,
+ ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+ ggml_row_size(kv_self->v_l[il]->type, id));
+ }

- uint32_t llama_n_seq_max(const struct llama_context * ctx) {
- return ctx->kv_self.size;
- }
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+ }

- const struct llama_model * llama_get_model(const struct llama_context * ctx) {
- return &ctx->model;
- }
+ i += nm - 1;
+ }

- enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
- return ctx->cparams.pooling_type;
- }
+ //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+ #endif

- void llama_attach_threadpool(
- struct llama_context * ctx,
- ggml_threadpool_t threadpool,
- ggml_threadpool_t threadpool_batch) {
- ctx->threadpool = threadpool;
- ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
+ return res;
  }

- void llama_detach_threadpool(struct llama_context * ctx) {
- ctx->threadpool = nullptr;
- ctx->threadpool_batch = nullptr;
- }
+ void llama_context::kv_self_update() {
720
+ auto & kv = kv_self;
625
721
 
626
- void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
627
- ctx->cparams.n_threads = n_threads;
628
- ctx->cparams.n_threads_batch = n_threads_batch;
629
- }
722
+ bool need_reserve = false;
630
723
 
631
- int32_t llama_n_threads(struct llama_context * ctx) {
632
- return ctx->cparams.n_threads;
633
- }
724
+ if (kv->has_shift) {
725
+ if (!kv->get_can_shift()) {
726
+ GGML_ABORT("The current context does not support K-shift");
727
+ }
634
728
 
635
- int32_t llama_n_threads_batch(struct llama_context * ctx) {
636
- return ctx->cparams.n_threads_batch;
637
- }
729
+ LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
638
730
 
639
- void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
640
- ctx->abort_callback = abort_callback;
641
- ctx->abort_callback_data = abort_callback_data;
731
+ // apply K-shift if needed
732
+ if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
733
+ ggml_backend_sched_reset(sched.get());
642
734
 
643
- for (auto & backend : ctx->backends) {
644
- auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
645
- auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
646
- if (set_abort_callback_fn) {
647
- set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
648
- }
649
- }
650
- }
735
+ auto * gf = graph_init();
651
736
 
652
- void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
653
- ctx->cparams.embeddings = embeddings;
654
- }
737
+ auto res = build_kv_self_shift(ctx_compute.get(), gf);
655
738
 
656
- void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
657
- ctx->cparams.causal_attn = causal_attn;
658
- }
739
+ ggml_backend_sched_alloc_graph(sched.get(), gf);
659
740
 
660
- void llama_synchronize(struct llama_context * ctx) {
661
- ggml_backend_sched_synchronize(ctx->sched.get());
741
+ res->set_inputs(nullptr);
662
742
 
663
- // FIXME: if multiple single tokens are evaluated without a synchronization,
664
- // the stats will be added to the prompt evaluation stats
665
- // this should only happen when using batch size 1 to evaluate a batch
743
+ graph_compute(gf, false);
666
744
 
667
- // add the evaluation to the stats
668
- if (ctx->n_queued_tokens == 1) {
669
- if (!ctx->cparams.no_perf) {
670
- ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
745
+ need_reserve = true;
671
746
  }
672
- ctx->n_eval++;
673
- } else if (ctx->n_queued_tokens > 1) {
674
- if (!ctx->cparams.no_perf) {
675
- ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
747
+
748
+ {
749
+ kv->has_shift = false;
750
+
751
+ for (uint32_t i = 0; i < kv->size; ++i) {
752
+ kv->cells[i].delta = 0;
753
+ }
676
754
  }
677
- ctx->n_p_eval += ctx->n_queued_tokens;
678
755
  }
679
756
 
680
- // get a more accurate load time, upon first eval
681
- if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
682
- ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
683
- ctx->has_evaluated_once = true;
684
- }
757
+ // defragment the KV cache if needed
758
+ if (kv->do_defrag) {
759
+ LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
685
760
 
686
- ctx->n_queued_tokens = 0;
687
- ctx->t_compute_start_us = 0;
688
- }
761
+ if (kv->defrag_prepare(graph_max_nodes())) {
762
+ ggml_backend_sched_reset(sched.get());
689
763
 
690
- float * llama_get_logits(struct llama_context * ctx) {
691
- llama_synchronize(ctx);
764
+ auto * gf = graph_init();
692
765
 
693
- // reorder logits for backward compatibility
694
- // TODO: maybe deprecate this
695
- llama_output_reorder(*ctx);
766
+ auto res = build_kv_self_defrag(ctx_compute.get(), gf);
696
767
 
697
- return ctx->logits;
698
- }
768
+ ggml_backend_sched_alloc_graph(sched.get(), gf);
699
769
 
700
- float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
701
- int32_t j = -1;
770
+ res->set_inputs(nullptr);
702
771
 
703
- llama_synchronize(ctx);
772
+ graph_compute(gf, false);
704
773
 
705
- try {
706
- if (ctx->logits == nullptr) {
707
- throw std::runtime_error("no logits");
774
+ need_reserve = true;
708
775
  }
709
776
 
710
- if (i < 0) {
711
- j = ctx->n_outputs + i;
712
- if (j < 0) {
713
- throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
777
+ kv->do_defrag = false;
778
+ }
779
+
780
+ // reserve a worst case graph if needed
781
+ if (need_reserve) {
782
+ LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
783
+
784
+ // build worst-case graph
785
+ uint32_t n_seqs = 1; // TODO: worst-case number of sequences
786
+ uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
787
+
788
+ // simulate full KV cache
789
+ kv_self->n = kv_self->size;
790
+
791
+ llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
792
+ llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
793
+
794
+ auto * gf = graph_init();
795
+ graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
796
+
797
+ // initialize scheduler with the worst-case graph
798
+ ggml_backend_sched_reset(sched.get());
799
+ if (!ggml_backend_sched_reserve(sched.get(), gf)) {
800
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
801
+ }
802
+ }
803
+ }
+
+enum llama_pooling_type llama_context::pooling_type() const {
+    return cparams.pooling_type;
+}
+
+float * llama_context::get_logits() {
+    // reorder logits for backward compatibility
+    output_reorder();
+
+    return logits;
+}
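These members back the existing llama_get_logits / llama_get_logits_ith C API. A short usage sketch follows (hypothetical variable names; it assumes a context and a tokenized prompt already exist, and that llama_batch_get_one takes the token pointer and count as in current llama.h):

// Sketch: decode one batch and read the logits of the last token through the
// public API; `ctx` and `tokens` are assumed to exist already.
llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
if (llama_decode(ctx, batch) != 0) {
    return 1; // decode failed or was aborted
}
const float * logits = llama_get_logits_ith(ctx, -1); // negative index counts from the end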
815
+
816
+ float * llama_context::get_logits_ith(int32_t i) {
817
+ int32_t j = -1;
818
+
819
+ try {
820
+ if (logits == nullptr) {
821
+ throw std::runtime_error("no logits");
822
+ }
823
+
824
+ if (i < 0) {
825
+ j = n_outputs + i;
826
+ if (j < 0) {
827
+ throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
714
828
  }
715
- } else if ((size_t) i >= ctx->output_ids.size()) {
716
- throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
829
+ } else if ((size_t) i >= output_ids.size()) {
830
+ throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
717
831
  } else {
718
- j = ctx->output_ids[i];
832
+ j = output_ids[i];
719
833
  }
720
834
 
721
835
  if (j < 0) {
722
836
  throw std::runtime_error(format("batch.logits[%d] != true", i));
723
837
  }
724
- if (j >= ctx->n_outputs) {
838
+ if (j >= n_outputs) {
725
839
  // This should not happen
726
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
840
+ throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
727
841
  }
728
842
 
729
- return ctx->logits + j*ctx->model.vocab.n_tokens();
843
+ return logits + j*model.vocab.n_tokens();
730
844
  } catch (const std::exception & err) {
731
845
  LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
732
846
  #ifndef NDEBUG
@@ -737,46 +851,41 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
737
851
  }
738
852
  }
739
853
 
740
- float * llama_get_embeddings(struct llama_context * ctx) {
741
- llama_synchronize(ctx);
742
-
854
+ float * llama_context::get_embeddings() {
743
855
  // reorder embeddings for backward compatibility
744
- // TODO: maybe deprecate this
745
- llama_output_reorder(*ctx);
856
+ output_reorder();
746
857
 
747
- return ctx->embd;
858
+ return embd;
748
859
  }
749
860
 
750
- float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
861
+ float * llama_context::get_embeddings_ith(int32_t i) {
751
862
  int32_t j = -1;
752
863
 
753
- llama_synchronize(ctx);
754
-
755
864
  try {
756
- if (ctx->embd == nullptr) {
865
+ if (embd == nullptr) {
757
866
  throw std::runtime_error("no embeddings");
758
867
  }
759
868
 
760
869
  if (i < 0) {
761
- j = ctx->n_outputs + i;
870
+ j = n_outputs + i;
762
871
  if (j < 0) {
763
- throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
872
+ throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
764
873
  }
765
- } else if ((size_t) i >= ctx->output_ids.size()) {
766
- throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
874
+ } else if ((size_t) i >= output_ids.size()) {
875
+ throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
767
876
  } else {
768
- j = ctx->output_ids[i];
877
+ j = output_ids[i];
769
878
  }
770
879
 
771
880
  if (j < 0) {
772
881
  throw std::runtime_error(format("batch.logits[%d] != true", i));
773
882
  }
774
- if (j >= ctx->n_outputs) {
883
+ if (j >= n_outputs) {
775
884
  // This should not happen
776
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
885
+ throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
777
886
  }
778
887
 
779
- return ctx->embd + j*ctx->model.hparams.n_embd;
888
+ return embd + j*model.hparams.n_embd;
780
889
  } catch (const std::exception & err) {
781
890
  LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
782
891
  #ifndef NDEBUG
@@ -787,696 +896,943 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     }
 }
 
-float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_synchronize(ctx);
-
-    auto it = ctx->embd_seq.find(seq_id);
-    if (it == ctx->embd_seq.end()) {
+float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
+    auto it = embd_seq.find(seq_id);
+    if (it == embd_seq.end()) {
         return nullptr;
     }
 
     return it->second.data();
 }
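get_embeddings_seq() is the member behind the existing llama_get_embeddings_seq() C function. A pooled-embedding read might look like the sketch below (assumes the context was created with embeddings enabled and a pooling type other than NONE):

// Sketch: after llama_decode() on a sequence, fetch its pooled embedding.
// A null return means no pooled embedding exists for this seq_id.
const float * emb = llama_get_embeddings_seq(ctx, /*seq_id =*/ 0);
if (emb == nullptr) {
    fprintf(stderr, "no pooled embedding for sequence 0\n");
}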
 
-// llama state API
+void llama_context::attach_threadpool(
+        ggml_threadpool_t threadpool,
+        ggml_threadpool_t threadpool_batch) {
+    LLAMA_LOG_DEBUG("%s: call\n", __func__);
 
-// deprecated
-size_t llama_get_state_size(struct llama_context * ctx) {
-    return llama_state_get_size(ctx);
+    this->threadpool       = threadpool;
+    this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }
 
-// deprecated
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    return llama_state_get_data(ctx, dst, -1);
-}
+void llama_context::detach_threadpool() {
+    LLAMA_LOG_DEBUG("%s: call\n", __func__);
 
-// deprecated
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    return llama_state_set_data(ctx, src, -1);
+    this->threadpool = nullptr;
+    this->threadpool_batch = nullptr;
 }
 
-// deprecated
-bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-}
+void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) {
+    LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch);
 
-// deprecated
-bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
+    cparams.n_threads       = n_threads;
+    cparams.n_threads_batch = n_threads_batch;
 }
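set_n_threads() backs the long-standing llama_set_n_threads() entry point; a typical call site separates single-token generation threads from prompt-processing threads, for example:

// Sketch: fewer threads for token-by-token generation than for batch (prompt)
// processing; the numbers are illustrative, not recommendations from this diff.
llama_set_n_threads(ctx, /*n_threads =*/ 4, /*n_threads_batch =*/ 8);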
827
930
 
828
- // TODO: replace all non-fatal assertions with returned errors or exceptions
829
- struct llama_data_write {
830
- virtual void write(const void * src, size_t size) = 0;
831
- virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
832
- virtual size_t get_size_written() = 0;
833
- virtual ~llama_data_write() = default;
931
+ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) {
932
+ LLAMA_LOG_DEBUG("%s: call\n", __func__);
834
933
 
835
- void write_string(const std::string & str) {
836
- uint32_t str_size = str.size();
934
+ this->abort_callback = abort_callback;
935
+ this->abort_callback_data = abort_callback_data;
837
936
 
838
- write(&str_size, sizeof(str_size));
839
- write(str.data(), str_size);
937
+ for (auto & backend : backends) {
938
+ auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
939
+ auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
940
+ if (set_abort_callback_fn) {
941
+ set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
942
+ }
840
943
  }
944
+ }
841
945
 
842
- void write_model_info(const struct llama_context * ctx) {
843
- const std::string arch_str = llm_arch_name(ctx->model.arch);
844
- write_string(arch_str);
845
- // TODO: add more model-specific info which should prevent loading the session file if not identical
846
- }
946
+ void llama_context::set_embeddings(bool value) {
947
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
847
948
 
848
- //void write_rng(const std::mt19937 & rng) {
849
- // std::ostringstream rng_ss;
850
- // rng_ss << rng;
949
+ cparams.embeddings = value;
950
+ }
851
951
 
852
- // const std::string & rng_str = rng_ss.str();
952
+ void llama_context::set_causal_attn(bool value) {
953
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
853
954
 
854
- // write_string(rng_str);
855
- //}
955
+ cparams.causal_attn = value;
956
+ }
856
957
 
857
- void write_output_ids(struct llama_context * ctx) {
858
- llama_output_reorder(*ctx);
958
+ void llama_context::set_warmup(bool value) {
959
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
859
960
 
860
- const uint32_t n_outputs = ctx->n_outputs;
961
+ cparams.warmup = value;
962
+ }
861
963
 
862
- std::vector<int32_t> output_pos;
964
+ void llama_context::set_adapter_lora(
965
+ llama_adapter_lora * adapter,
966
+ float scale) {
967
+ LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
863
968
 
864
- const size_t n_batch = ctx->cparams.n_batch;
865
- const auto & output_ids = ctx->output_ids;
969
+ loras[adapter] = scale;
970
+ }
866
971
 
867
- GGML_ASSERT(n_outputs <= ctx->output_size);
972
+ bool llama_context::rm_adapter_lora(
973
+ llama_adapter_lora * adapter) {
974
+ LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
868
975
 
869
- output_pos.resize(n_outputs);
976
+ auto pos = loras.find(adapter);
977
+ if (pos != loras.end()) {
978
+ loras.erase(pos);
979
+ return true;
980
+ }
870
981
 
871
- // build a more compact representation of the output ids
872
- for (size_t i = 0; i < n_batch; ++i) {
873
- // map an output id to a position in the batch
874
- int32_t pos = output_ids[i];
875
- if (pos >= 0) {
876
- GGML_ASSERT((uint32_t) pos < n_outputs);
877
- output_pos[pos] = i;
878
- }
879
- }
982
+ return false;
983
+ }
880
984
 
881
- write(&n_outputs, sizeof(n_outputs));
985
+ void llama_context::clear_adapter_lora() {
986
+ LLAMA_LOG_DEBUG("%s: call\n", __func__);
882
987
 
883
- if (n_outputs) {
884
- write(output_pos.data(), n_outputs * sizeof(int32_t));
885
- }
886
- }
988
+ loras.clear();
989
+ }
887
990
 
888
- void write_logits(const struct llama_context * ctx) {
889
- const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
991
+ bool llama_context::apply_adapter_cvec(
992
+ const float * data,
993
+ size_t len,
994
+ int32_t n_embd,
995
+ int32_t il_start,
996
+ int32_t il_end) {
997
+ LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
890
998
 
891
- write(&logits_size, sizeof(logits_size));
999
+ return cvec.apply(model, data, len, n_embd, il_start, il_end);
1000
+ }
892
1001
 
893
- if (logits_size) {
894
- write(ctx->logits, logits_size * sizeof(float));
895
- }
1002
+ int llama_context::encode(llama_batch & inp_batch) {
1003
+ if (inp_batch.n_tokens == 0) {
1004
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
1005
+ return -1;
896
1006
  }
897
1007
 
898
- void write_embeddings(const struct llama_context * ctx) {
899
- const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
900
-
901
- write(&embeddings_size, sizeof(embeddings_size));
1008
+ // temporary allocate memory for the input batch if needed
1009
+ // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
1010
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
902
1011
 
903
- if (embeddings_size) {
904
- write(ctx->embd, embeddings_size * sizeof(float));
905
- }
906
- }
1012
+ const llama_batch & batch = batch_allocr.batch;
1013
+ const int32_t n_tokens = batch.n_tokens;
907
1014
 
908
- void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
909
- for (const auto & range : cell_ranges) {
910
- for (uint32_t i = range.first; i < range.second; ++i) {
911
- const auto & cell = kv_self.cells[i];
912
- const llama_pos pos = cell.pos;
913
- const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
1015
+ const auto & hparams = model.hparams;
914
1016
 
915
- write(&pos, sizeof(pos));
916
- write(&n_seq_id, sizeof(n_seq_id));
1017
+ GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
917
1018
 
918
- if (n_seq_id) {
919
- for (auto seq_id : cell.seq_id) {
920
- write(&seq_id, sizeof(seq_id));
921
- }
922
- }
1019
+ if (batch.token) {
1020
+ for (int32_t i = 0; i < n_tokens; ++i) {
1021
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
1022
+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
1023
+ return -1;
923
1024
  }
924
1025
  }
925
1026
  }
926
1027
 
927
- void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
928
- const struct llama_kv_cache & kv_self = ctx->kv_self;
929
- const struct llama_hparams & hparams = ctx->model.hparams;
1028
+ // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
1029
+ GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
930
1030
 
931
- const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
932
- const uint32_t n_layer = hparams.n_layer;
1031
+ if (t_compute_start_us == 0) {
1032
+ t_compute_start_us = ggml_time_us();
1033
+ }
933
1034
 
934
- write(&v_trans, sizeof(v_trans));
935
- write(&n_layer, sizeof(n_layer));
1035
+ n_queued_tokens += n_tokens;
936
1036
 
937
- std::vector<uint8_t> tmp_buf;
1037
+ const int64_t n_embd = hparams.n_embd;
938
1038
 
939
- // Iterate and write all the keys first, each row is a cell
940
- // Get whole range at a time
941
- for (uint32_t il = 0; il < n_layer; ++il) {
942
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1039
+ sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
943
1040
 
944
- // Write key type
945
- const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
946
- write(&k_type_i, sizeof(k_type_i));
1041
+ const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
947
1042
 
948
- // Write row size of key
949
- const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
950
- write(&k_size_row, sizeof(k_size_row));
1043
+ // reserve output buffer
1044
+ if (output_reserve(n_tokens) < n_tokens) {
1045
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
1046
+ return -2;
1047
+ };
951
1048
 
952
- // Read each range of cells of k_size length each into tmp_buf and write out
953
- for (const auto & range : cell_ranges) {
954
- const size_t range_size = range.second - range.first;
955
- const size_t buf_size = range_size * k_size_row;
956
- write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
957
- }
958
- }
1049
+ for (int32_t i = 0; i < n_tokens; ++i) {
1050
+ output_ids[i] = i;
1051
+ }
959
1052
 
960
- if (!kv_self.v_trans) {
961
- for (uint32_t il = 0; il < n_layer; ++il) {
962
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1053
+ n_outputs = n_tokens;
963
1054
 
964
- // Write value type
965
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
966
- write(&v_type_i, sizeof(v_type_i));
1055
+ //batch_manager->prepare(ubatch);
967
1056
 
968
- // Write row size of value
969
- const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
970
- write(&v_size_row, sizeof(v_size_row));
1057
+ ggml_backend_sched_reset(sched.get());
1058
+ ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
971
1059
 
972
- // Read each range of cells of v_size length each into tmp_buf and write out
973
- for (const auto & range : cell_ranges) {
974
- const size_t range_size = range.second - range.first;
975
- const size_t buf_size = range_size * v_size_row;
976
- write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
977
- }
978
- }
979
- } else {
980
- // When v is transposed, we also need the element size and get the element ranges from each row
981
- const uint32_t kv_size = kv_self.size;
982
- for (uint32_t il = 0; il < n_layer; ++il) {
983
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1060
+ const auto causal_attn_org = cparams.causal_attn;
984
1061
 
985
- // Write value type
986
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
987
- write(&v_type_i, sizeof(v_type_i));
1062
+ // always use non-causal attention for encoder graphs
1063
+ // TODO: this is a tmp solution until we have a proper way to support enc-dec models
1064
+ // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
1065
+ cparams.causal_attn = false;
988
1066
 
989
- // Write element size
990
- const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
991
- write(&v_size_el, sizeof(v_size_el));
1067
+ auto * gf = graph_init();
1068
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
992
1069
 
993
- // Write GQA embedding size
994
- write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
1070
+ ggml_backend_sched_alloc_graph(sched.get(), gf);
995
1071
 
996
- // For each row, we get the element values of each cell
997
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
998
- // Read each range of cells of v_size_el length each into tmp_buf and write out
999
- for (const auto & range : cell_ranges) {
1000
- const size_t range_size = range.second - range.first;
1001
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1002
- const size_t buf_size = range_size * v_size_el;
1003
- write_tensor_data(kv_self.v_l[il], src_offset, buf_size);
1004
- }
1005
- }
1006
- }
1007
- }
1072
+ res->set_inputs(&ubatch);
1073
+
1074
+ cparams.causal_attn = causal_attn_org;
1075
+
1076
+ const auto compute_status = graph_compute(gf, n_tokens > 1);
1077
+ switch (compute_status) {
1078
+ case GGML_STATUS_SUCCESS:
1079
+ break;
1080
+ case GGML_STATUS_ABORTED:
1081
+ return 2;
1082
+ case GGML_STATUS_ALLOC_FAILED:
1083
+ return -2;
1084
+ case GGML_STATUS_FAILED:
1085
+ default:
1086
+ return -3;
1008
1087
  }
1009
1088
 
1010
- void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
1011
- const struct llama_kv_cache & kv_self = ctx->kv_self;
1012
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1013
- uint32_t cell_count = 0;
1089
+ auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
1014
1090
 
1015
- // Count the number of cells with the specified seq_id
1016
- // Find all the ranges of cells with this seq id (or all, when -1)
1017
- uint32_t cell_range_begin = kv_self.size;
1018
- for (uint32_t i = 0; i < kv_self.size; ++i) {
1019
- const auto & cell = kv_self.cells[i];
1020
- if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
1021
- ++cell_count;
1022
- if (cell_range_begin == kv_self.size) {
1023
- cell_range_begin = i;
1091
+ // extract embeddings
1092
+ if (t_embd) {
1093
+ ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1094
+ GGML_ASSERT(backend_embd != nullptr);
1095
+
1096
+ GGML_ASSERT(embd != nullptr);
1097
+
1098
+ switch (cparams.pooling_type) {
1099
+ case LLAMA_POOLING_TYPE_NONE:
1100
+ {
1101
+ // extract token embeddings
1102
+ GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
1103
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
1104
+ } break;
1105
+ case LLAMA_POOLING_TYPE_MEAN:
1106
+ case LLAMA_POOLING_TYPE_CLS:
1107
+ case LLAMA_POOLING_TYPE_LAST:
1108
+ {
1109
+ // extract sequence embeddings
1110
+ auto & embd_seq_out = embd_seq;
1111
+ embd_seq_out.clear();
1112
+
1113
+ GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
1114
+
1115
+ for (int32_t i = 0; i < n_tokens; i++) {
1116
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
1117
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1118
+ continue;
1119
+ }
1120
+ embd_seq_out[seq_id].resize(n_embd);
1121
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
1122
+ }
1123
+ } break;
1124
+ case LLAMA_POOLING_TYPE_RANK:
1125
+ {
1126
+ // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
1127
+ // wait for an encoder model that requires this pooling type in order to test it
1128
+ // https://github.com/ggerganov/llama.cpp/pull/9510
1129
+ GGML_ABORT("RANK pooling not implemented yet");
1024
1130
  }
1025
- } else {
1026
- if (cell_range_begin != kv_self.size) {
1027
- cell_ranges.emplace_back(cell_range_begin, i);
1028
- cell_range_begin = kv_self.size;
1131
+ case LLAMA_POOLING_TYPE_UNSPECIFIED:
1132
+ {
1133
+ GGML_ABORT("unknown pooling type");
1029
1134
  }
1030
- }
1031
1135
  }
1032
- if (cell_range_begin != kv_self.size) {
1033
- cell_ranges.emplace_back(cell_range_begin, kv_self.size);
1034
- }
1035
-
1036
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1037
- uint32_t cell_count_check = 0;
1038
- for (const auto & range : cell_ranges) {
1039
- cell_count_check += range.second - range.first;
1040
- }
1041
- GGML_ASSERT(cell_count == cell_count_check);
1136
+ }
1042
1137
 
1043
- write(&cell_count, sizeof(cell_count));
1138
+ // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1139
+ // overlap with device computation.
1140
+ ggml_backend_sched_reset(sched.get());
1044
1141
 
1045
- write_kv_cache_meta(kv_self, cell_ranges, seq_id);
1046
- write_kv_cache_data(ctx, cell_ranges);
1047
- }
1048
- };
1142
+ // TODO: hacky solution
1143
+ if (model.arch == LLM_ARCH_T5 && t_embd) {
1144
+ //cross.t_embd = t_embd;
1049
1145
 
1050
- struct llama_data_read {
1051
- virtual const uint8_t * read(size_t size) = 0;
1052
- virtual void read_to(void * dst, size_t size) = 0;
1053
- virtual size_t get_size_read() = 0;
1054
- virtual ~llama_data_read() = default;
1146
+ synchronize();
1055
1147
 
1056
- void read_string(std::string & str) {
1057
- uint32_t str_size;
1058
- read_to(&str_size, sizeof(str_size));
1148
+ cross.n_embd = t_embd->ne[0];
1149
+ cross.n_enc = t_embd->ne[1];
1150
+ cross.v_embd.resize(cross.n_embd*cross.n_enc);
1151
+ memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
1059
1152
 
1060
- str.assign((const char *) read(str_size), str_size);
1153
+ // remember the sequence ids used during the encoding - needed for cross attention later
1154
+ cross.seq_ids_enc.resize(n_tokens);
1155
+ for (int32_t i = 0; i < n_tokens; i++) {
1156
+ cross.seq_ids_enc[i].clear();
1157
+ for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
1158
+ llama_seq_id seq_id = ubatch.seq_id[i][s];
1159
+ cross.seq_ids_enc[i].insert(seq_id);
1160
+ }
1161
+ }
1061
1162
  }
1062
1163
 
1063
- // validate model information
1064
- void read_model_info(const struct llama_context * ctx) {
1065
- const std::string cur_arch_str = llm_arch_name(ctx->model.arch);
1164
+ return 0;
1165
+ }
1066
1166
 
1067
- std::string arch_str;
1068
- read_string(arch_str);
1069
- if (cur_arch_str != arch_str) {
1070
- throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
1071
- }
1072
- // TODO: add more info which needs to be identical but which is not verified otherwise
1167
+ int llama_context::decode(llama_batch & inp_batch) {
1168
+ if (inp_batch.n_tokens == 0) {
1169
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
1170
+ return -1;
1073
1171
  }
1074
1172
 
1075
- //void read_rng(std::mt19937 & rng) {
1076
- // std::string rng_str;
1077
- // read_string(rng_str);
1173
+ // temporary allocate memory for the input batch if needed
1174
+ // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
1175
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
1078
1176
 
1079
- // std::istringstream rng_ss(rng_str);
1080
- // rng_ss >> rng;
1177
+ const llama_batch & batch = batch_allocr.batch;
1081
1178
 
1082
- // if (rng_ss.fail()) {
1083
- // throw std::runtime_error("failed to load RNG state");
1084
- // }
1085
- //}
1179
+ const auto & vocab = model.vocab;
1180
+ const auto & hparams = model.hparams;
1086
1181
 
1087
- void read_output_ids(struct llama_context * ctx) {
1088
- std::vector<int32_t> output_pos;
1182
+ const int32_t n_vocab = vocab.n_tokens();
1089
1183
 
1090
- uint32_t n_outputs;
1091
- read_to(&n_outputs, sizeof(n_outputs));
1184
+ const int64_t n_tokens_all = batch.n_tokens;
1185
+ const int64_t n_embd = hparams.n_embd;
1092
1186
 
1093
- if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
1094
- throw std::runtime_error("could not reserve outputs");
1187
+ // TODO: remove this stuff
1188
+ class batch_guard {
1189
+ public:
1190
+ batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
1095
1191
  }
1096
1192
 
1097
- if (n_outputs) {
1098
- output_pos.resize(n_outputs);
1099
- read_to(output_pos.data(), n_outputs * sizeof(int32_t));
1100
-
1101
- for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
1102
- int32_t id = output_pos[i];
1103
- if ((uint32_t) id >= ctx->cparams.n_batch) {
1104
- throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
1105
- }
1106
- ctx->output_ids[id] = i;
1193
+ ~batch_guard() {
1194
+ if (!is_done) {
1195
+ kv_slot_restorer.restore();
1107
1196
  }
1108
-
1109
- ctx->n_outputs = n_outputs;
1110
1197
  }
1111
- }
1112
1198
 
1113
- void read_logits(struct llama_context * ctx) {
1114
- uint64_t logits_size;
1115
- read_to(&logits_size, sizeof(logits_size));
1116
-
1117
- if (ctx->logits_size < logits_size) {
1118
- throw std::runtime_error("logits buffer too small");
1199
+ void done() {
1200
+ is_done = true;
1119
1201
  }
1120
1202
 
1121
- if (logits_size) {
1122
- read_to(ctx->logits, logits_size * sizeof(float));
1203
+ void save(const llama_kv_cache_slot_info & slot_info) {
1204
+ kv_slot_restorer.save(slot_info);
1123
1205
  }
1124
- }
1125
1206
 
1126
- void read_embeddings(struct llama_context * ctx) {
1127
- uint64_t embeddings_size;
1128
- read_to(&embeddings_size, sizeof(embeddings_size));
1207
+ private:
1208
+ bool is_done = false;
1129
1209
 
1130
- if (ctx->embd_size < embeddings_size) {
1131
- throw std::runtime_error("embeddings buffer too small");
1132
- }
1210
+ llama_kv_slot_restorer kv_slot_restorer;
1211
+ };
1212
+
1213
+ batch_guard bg(*kv_self);
1133
1214
 
1134
- if (embeddings_size) {
1135
- read_to(ctx->embd, embeddings_size * sizeof(float));
1215
+ GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
1216
+
1217
+ if (batch.token) {
1218
+ for (int64_t i = 0; i < n_tokens_all; ++i) {
1219
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
1220
+ LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
1221
+ throw std::runtime_error("invalid token");
1222
+ }
1136
1223
  }
1137
1224
  }
1138
1225
 
1139
- bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
1140
- struct llama_kv_cache & kv_self = ctx->kv_self;
1226
+ GGML_ASSERT(n_tokens_all <= cparams.n_batch);
1141
1227
 
1142
- if (dest_seq_id != -1) {
1143
- // single sequence
1228
+ GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
1144
1229
 
1145
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
1230
+ if (t_compute_start_us == 0) {
1231
+ t_compute_start_us = ggml_time_us();
1232
+ }
1233
+ n_queued_tokens += n_tokens_all;
1146
1234
 
1147
- llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1148
- batch.n_tokens = cell_count;
1149
- batch.n_seq_tokens = cell_count;
1150
- batch.n_seqs = 1;
1235
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
1236
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
1151
1237
 
1152
- for (uint32_t i = 0; i < cell_count; ++i) {
1153
- llama_pos pos;
1154
- uint32_t n_seq_id;
1238
+ embd_seq.clear();
1155
1239
 
1156
- read_to(&pos, sizeof(pos));
1157
- read_to(&n_seq_id, sizeof(n_seq_id));
1240
+ int64_t n_outputs_all = 0;
1158
1241
 
1159
- if (n_seq_id != 0) {
1160
- LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1161
- return false;
1162
- }
1242
+ // count outputs
1243
+ if (batch.logits && !embd_pooled) {
1244
+ for (uint32_t i = 0; i < n_tokens_all; ++i) {
1245
+ n_outputs_all += batch.logits[i] != 0;
1246
+ }
1247
+ } else if (logits_all || embd_pooled) {
1248
+ n_outputs_all = n_tokens_all;
1249
+ } else {
1250
+ // keep last output only
1251
+ n_outputs_all = 1;
1252
+ }
1163
1253
 
1164
- batch.pos[i] = pos;
1165
- }
1166
- batch.n_seq_id[0] = 1;
1167
- batch.seq_id[0] = &dest_seq_id;
1168
- if (!llama_kv_cache_find_slot(kv_self, batch)) {
1169
- LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1170
- return false;
1171
- }
1254
+ const bool logits_all = n_outputs_all == n_tokens_all;
1172
1255
 
1173
- // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1174
- // Assume that this is one contiguous block of cells
1175
- GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
1176
- GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
1177
- GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
1178
- GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
1179
- GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
1180
- } else {
1181
- // whole KV cache restore
1256
+ sbatch.from_batch(batch, n_embd,
1257
+ /* simple_split */ !kv_self->recurrent,
1258
+ /* logits_all */ logits_all);
1182
1259
 
1183
- if (cell_count > kv_self.size) {
1184
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1185
- return false;
1186
- }
1260
+ // reserve output buffer
1261
+ if (output_reserve(n_outputs_all) < n_outputs_all) {
1262
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
1263
+ return -2;
1264
+ };
1265
+
1266
+ int64_t n_outputs_prev = 0;
1187
1267
 
1188
- llama_kv_cache_clear(kv_self);
1268
+ while (sbatch.n_tokens > 0) {
1269
+ llama_ubatch ubatch = llama_ubatch();
1189
1270
 
1190
- for (uint32_t i = 0; i < cell_count; ++i) {
1191
- llama_kv_cell & cell = kv_self.cells[i];
1271
+ const auto & n_ubatch = cparams.n_ubatch;
1192
1272
 
1193
- llama_pos pos;
1194
- uint32_t n_seq_id;
1273
+ if (kv_self->recurrent) {
1274
+ if (embd_pooled) {
1275
+ // Pooled embeddings cannot be split across ubatches (yet)
1276
+ ubatch = sbatch.split_seq(cparams.n_ubatch);
1277
+ } else {
1278
+ // recurrent model architectures are easier to implement
1279
+ // with equal-length sequences
1280
+ ubatch = sbatch.split_equal(cparams.n_ubatch);
1281
+ }
1282
+ } else {
1283
+ ubatch = sbatch.split_simple(n_ubatch);
1284
+ }
1195
1285
 
1196
- read_to(&pos, sizeof(pos));
1197
- read_to(&n_seq_id, sizeof(n_seq_id));
1286
+ // count the outputs in this u_batch
1287
+ {
1288
+ int32_t n_outputs_new = 0;
1198
1289
 
1199
- cell.pos = pos;
1290
+ if (n_outputs_all == n_tokens_all) {
1291
+ n_outputs_new = ubatch.n_tokens;
1292
+ } else {
1293
+ GGML_ASSERT(ubatch.output);
1294
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
1295
+ n_outputs_new += (int32_t) (ubatch.output[i] != 0);
1296
+ }
1297
+ }
1200
1298
 
1201
- for (uint32_t j = 0; j < n_seq_id; ++j) {
1202
- llama_seq_id seq_id;
1203
- read_to(&seq_id, sizeof(seq_id));
1299
+ // needs to happen before the graph is built
1300
+ n_outputs = n_outputs_new;
1301
+ }
1204
1302
 
1205
- if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
1206
- LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
1207
- return false;
1208
- }
1303
+ // non-causal masks do not use the KV cache
1304
+ if (hparams.causal_attn) {
1305
+ kv_self_update();
1209
1306
 
1210
- cell.seq_id.insert(seq_id);
1307
+ // if we have enough unused cells before the current head ->
1308
+ // better to start searching from the beginning of the cache, hoping to fill it
1309
+ if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
1310
+ kv_self->head = 0;
1311
+ }
1211
1312
 
1212
- if (kv_self.recurrent) {
1213
- int32_t & tail = kv_self.cells[seq_id].tail;
1214
- if (tail != -1) {
1215
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
1216
- return false;
1217
- }
1218
- tail = i;
1219
- }
1220
- }
1313
+ const auto slot_info = kv_self->find_slot(ubatch);
1314
+ if (!slot_info) {
1315
+ LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
1316
+ return -3;
1221
1317
  }
1222
1318
 
1223
- kv_self.head = 0;
1224
- kv_self.used = cell_count;
1225
- }
1319
+ bg.save(slot_info);
1226
1320
 
1227
- if (kv_self.recurrent) {
1228
- for (uint32_t i = 0; i < cell_count; ++i) {
1229
- uint32_t cell_id = kv_self.head + i;
1230
- // make sure the recurrent states will keep their restored state
1231
- kv_self.cells[cell_id].src = cell_id;
1321
+ if (!kv_self->recurrent) {
1322
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
1323
+ // after enough generations, the benefit from this heuristic disappears
1324
+ // if we start defragmenting the cache, the benefit from this will be more important
1325
+ const uint32_t pad = kv_self->get_padding(cparams);
1326
+ kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
1232
1327
  }
1233
1328
  }
1234
1329
 
1235
- return true;
1236
- }
1330
+ //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
1237
1331
 
1238
- bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
1239
- const struct llama_hparams & hparams = ctx->model.hparams;
1240
- struct llama_kv_cache & kv_self = ctx->kv_self;
1241
- uint32_t v_trans;
1242
- uint32_t n_layer;
1243
- read_to(&v_trans, sizeof(v_trans));
1244
- read_to(&n_layer, sizeof(n_layer));
1332
+ ggml_backend_sched_reset(sched.get());
1333
+ ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1245
1334
 
1246
- if (n_layer != hparams.n_layer) {
1247
- LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
1248
- return false;
1249
- }
1250
- if (cell_count > kv_self.size) {
1251
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
1252
- return false;
1253
- }
1254
- if (kv_self.v_trans != (bool) v_trans) {
1255
- LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1256
- return false;
1257
- }
1335
+ auto * gf = graph_init();
1336
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
1258
1337
 
1259
- // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1260
- for (uint32_t il = 0; il < n_layer; ++il) {
1261
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1338
+ // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
1262
1339
 
1263
- // Read type of key
1264
- int32_t k_type_i_ref;
1265
- read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1266
- const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
1267
- if (k_type_i != k_type_i_ref) {
1268
- LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1269
- return false;
1270
- }
1340
+ ggml_backend_sched_alloc_graph(sched.get(), gf);
1341
+
1342
+ res->set_inputs(&ubatch);
1271
1343
 
1272
- // Read row size of key
1273
- uint64_t k_size_row_ref;
1274
- read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1275
- const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
1276
- if (k_size_row != k_size_row_ref) {
1277
- LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1278
- return false;
1344
+ const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
1345
+ if (compute_status != GGML_STATUS_SUCCESS) {
1346
+ switch (compute_status) {
1347
+ case GGML_STATUS_ABORTED:
1348
+ return 2;
1349
+ case GGML_STATUS_ALLOC_FAILED:
1350
+ return -2;
1351
+ case GGML_STATUS_FAILED:
1352
+ default:
1353
+ return -3;
1279
1354
  }
1355
+ }
1356
+
1357
+ // update the kv ring buffer
1358
+ {
1359
+ kv_self->head += ubatch.n_tokens;
1280
1360
 
1281
- if (cell_count) {
1282
- // Read and set the keys for the whole cell range
1283
- ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
1361
+ // Ensure kv cache head points to a valid index.
1362
+ if (kv_self->head >= kv_self->size) {
1363
+ kv_self->head = 0;
1284
1364
  }
1285
1365
  }
1286
1366
 
1287
- if (!kv_self.v_trans) {
1288
- for (uint32_t il = 0; il < n_layer; ++il) {
1289
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1367
+ // plot the computation graph in dot format (for debugging purposes)
1368
+ //if (n_past%100 == 0) {
1369
+ // ggml_graph_dump_dot(gf, NULL, "llama.dot");
1370
+ //}
1290
1371
 
1291
- // Read type of value
1292
- int32_t v_type_i_ref;
1293
- read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1294
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
1295
- if (v_type_i != v_type_i_ref) {
1296
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1297
- return false;
1298
- }
1372
+ auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
1373
+ auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1299
1374
 
1300
- // Read row size of value
1301
- uint64_t v_size_row_ref;
1302
- read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1303
- const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
1304
- if (v_size_row != v_size_row_ref) {
1305
- LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1306
- return false;
1307
- }
1375
+ if (t_embd && res->get_embd_pooled()) {
1376
+ t_embd = res->get_embd_pooled();
1377
+ }
1308
1378
 
1309
- if (cell_count) {
1310
- // Read and set the values for the whole cell range
1311
- ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
1312
- }
1313
- }
1314
- } else {
1315
- // For each layer, read the values for each cell (transposed)
1316
- for (uint32_t il = 0; il < n_layer; ++il) {
1317
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1318
-
1319
- // Read type of value
1320
- int32_t v_type_i_ref;
1321
- read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1322
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
1323
- if (v_type_i != v_type_i_ref) {
1324
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1325
- return false;
1326
- }
1379
+ // extract logits
1380
+ if (t_logits && n_outputs > 0) {
1381
+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1382
+ GGML_ASSERT(backend_res != nullptr);
1383
+ GGML_ASSERT(logits != nullptr);
1327
1384
 
1328
- // Read element size of value
1329
- uint32_t v_size_el_ref;
1330
- read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1331
- const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
1332
- if (v_size_el != v_size_el_ref) {
1333
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1334
- return false;
1335
- }
1385
+ float * logits_out = logits + n_outputs_prev*n_vocab;
1336
1386
 
1337
- // Read GQA embedding size
1338
- uint32_t n_embd_v_gqa_ref;
1339
- read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1340
- if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1341
- LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1342
- return false;
1343
- }
1387
+ if (n_outputs) {
1388
+ GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1389
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
1390
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
1391
+ }
1392
+ }
1344
1393
 
1345
- if (cell_count) {
1346
- // For each row in the transposed matrix, read the values for the whole cell range
1347
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1348
- const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
1349
- ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1394
+ // extract embeddings
1395
+ if (t_embd && n_outputs > 0) {
1396
+ ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1397
+ GGML_ASSERT(backend_embd != nullptr);
1398
+
1399
+ switch (cparams.pooling_type) {
1400
+ case LLAMA_POOLING_TYPE_NONE:
1401
+ {
1402
+ // extract token embeddings
1403
+ GGML_ASSERT(embd != nullptr);
1404
+ float * embd_out = embd + n_outputs_prev*n_embd;
1405
+
1406
+ if (n_outputs) {
1407
+ GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1408
+ GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
1409
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
1410
+ }
1411
+ } break;
1412
+ case LLAMA_POOLING_TYPE_MEAN:
1413
+ case LLAMA_POOLING_TYPE_CLS:
1414
+ case LLAMA_POOLING_TYPE_LAST:
1415
+ {
1416
+ // extract sequence embeddings (cleared before processing each batch)
1417
+ auto & embd_seq_out = embd_seq;
1418
+
1419
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1420
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
1421
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1422
+ continue;
1423
+ }
1424
+ embd_seq_out[seq_id].resize(n_embd);
1425
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
1426
+ }
1427
+ } break;
1428
+ case LLAMA_POOLING_TYPE_RANK:
1429
+ {
1430
+ // extract the rerank score - a single float per sequence
1431
+ auto & embd_seq_out = embd_seq;
1432
+
1433
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1434
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
1435
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1436
+ continue;
1437
+ }
1438
+ embd_seq_out[seq_id].resize(1);
1439
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
1440
+ }
1441
+ } break;
1442
+ case LLAMA_POOLING_TYPE_UNSPECIFIED:
1443
+ {
1444
+ GGML_ABORT("unknown pooling type");
1350
1445
  }
1351
- }
1352
1446
  }
1353
1447
  }
1354
- return true;
1448
+
1449
+ n_outputs_prev += n_outputs;
1355
1450
  }
1356
1451
 
1357
- void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
1358
- uint32_t cell_count;
1359
- read_to(&cell_count, sizeof(cell_count));
1452
+ // finalize the batch processing
1453
+ bg.done();
1454
+
1455
+ // set output mappings
1456
+ {
1457
+ bool sorted_output = true;
1360
1458
 
1361
- bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
1459
+ GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
1362
1460
 
1363
- if (!res) {
1364
- if (seq_id == -1) {
1365
- llama_kv_cache_clear(ctx);
1366
- } else {
1367
- llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
1461
+ for (int64_t i = 0; i < n_outputs_all; ++i) {
1462
+ int64_t out_id = sbatch.out_ids[i];
1463
+ output_ids[out_id] = i;
1464
+ if (out_id != i) {
1465
+ sorted_output = false;
1368
1466
  }
1369
- throw std::runtime_error("failed to restore kv cache");
1370
1467
  }
1371
- }
1372
- };
1373
-
1374
- struct llama_data_write_dummy : llama_data_write {
1375
- size_t size_written = 0;
1376
1468
 
1377
- llama_data_write_dummy() {}
1378
-
1379
- void write(const void * /* src */, size_t size) override {
1380
- size_written += size;
1469
+ if (sorted_output) {
1470
+ sbatch.out_ids.clear();
1471
+ }
1381
1472
  }
1382
1473
 
1383
- void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
1384
- size_written += size;
1385
- }
1474
+ // set to total number of outputs in the batch, for use in llama_get_logits_ith
1475
+ n_outputs = n_outputs_all;
1386
1476
 
1387
- size_t get_size_written() override {
1388
- return size_written;
1389
- }
1390
- };
1477
+ // wait for the computation to finish (automatically done when obtaining the model output)
1478
+ //synchronize();
1391
1479
 
1392
- struct llama_data_write_buffer : llama_data_write {
1393
- uint8_t * ptr;
1394
- size_t buf_size = 0;
1395
- size_t size_written = 0;
1480
+ // decide if we need to defrag the kv cache
1481
+ if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
1482
+ // - do not defrag small contexts (i.e. < 2048 tokens)
1483
+ // - count the padding towards the number of used tokens
1484
+ const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
1396
1485
 
1397
- llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1486
+ // queue defragmentation for next llama_kv_cache_update
1487
+ if (fragmentation > cparams.defrag_thold) {
1488
+ LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
1398
1489
 
1399
- void write(const void * src, size_t size) override {
1400
- if (size > buf_size) {
1401
- throw std::runtime_error("unexpectedly reached end of buffer");
1490
+ kv_self->defrag();
1402
1491
  }
1403
- memcpy(ptr, src, size);
1404
- ptr += size;
1405
- size_written += size;
1406
- buf_size -= size;
1407
1492
  }
1408
1493
 
1409
- void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
1410
- if (size > buf_size) {
1411
- throw std::runtime_error("unexpectedly reached end of buffer");
1412
- }
1413
- ggml_backend_tensor_get(tensor, ptr, offset, size);
1414
- ptr += size;
1415
- size_written += size;
1416
- buf_size -= size;
1417
- }
1494
+ // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1495
+ // overlap with device computation.
1496
+ ggml_backend_sched_reset(sched.get());
1418
1497
 
1419
- size_t get_size_written() override {
1420
- return size_written;
1421
- }
1422
- };
1498
+ return 0;
1499
+ }
1423
1500
 
1424
- struct llama_data_read_buffer : llama_data_read {
1425
- const uint8_t * ptr;
1426
- size_t buf_size = 0;
1427
- size_t size_read = 0;
1501
+ //
1502
+ // output
1503
+ //
1428
1504
 
1429
- llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1505
+ int32_t llama_context::output_reserve(int32_t n_outputs) {
1506
+ const auto & hparams = model.hparams;
1507
+ const auto & vocab = model.vocab;
1430
1508
 
1431
- const uint8_t * read(size_t size) override {
1432
- const uint8_t * base_ptr = ptr;
1433
- if (size > buf_size) {
1434
- throw std::runtime_error("unexpectedly reached end of buffer");
1435
- }
1436
- ptr += size;
1437
- size_read += size;
1438
- buf_size -= size;
1439
- return base_ptr;
1440
- }
1509
+ const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
1441
1510
 
1442
- void read_to(void * dst, size_t size) override {
1443
- memcpy(dst, read(size), size);
1444
- }
1511
+ const auto n_batch = cparams.n_batch;
1512
+ const auto n_vocab = vocab.n_tokens();
1513
+ const auto n_embd = hparams.n_embd;
1445
1514
 
1446
- size_t get_size_read() override {
1447
- return size_read;
1448
- }
1449
- };
1515
+ // TODO: use a per-batch flag for logits presence instead
1516
+ bool has_logits = !cparams.embeddings;
1517
+ bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1450
1518
 
1451
- struct llama_data_write_file : llama_data_write {
1452
- llama_file * file;
1453
- size_t size_written = 0;
1454
- std::vector<uint8_t> temp_buffer;
1519
+ // TODO: hacky enc-dec support
1520
+ if (model.arch == LLM_ARCH_T5) {
1521
+ has_logits = true;
1522
+ has_embd = true;
1523
+ }
1455
1524
 
1456
- llama_data_write_file(llama_file * f) : file(f) {}
1525
+ logits_size = has_logits ? n_vocab*n_outputs_max : 0;
1526
+ embd_size = has_embd ? n_embd*n_outputs_max : 0;
1457
1527
 
1458
- void write(const void * src, size_t size) override {
1459
- file->write_raw(src, size);
1460
- size_written += size;
1528
+ if (output_ids.empty()) {
1529
+ // init, never resized afterwards
1530
+ output_ids.resize(n_batch);
1461
1531
  }
1462
1532
 
1463
- void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
1533
+ const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
1534
+ const size_t new_size = (logits_size + embd_size) * sizeof(float);
1535
+
1536
+ // alloc only when more than the current capacity is required
1537
+ // TODO: also consider shrinking the buffer
1538
+ if (!buf_output || prev_size < new_size) {
1539
+ if (buf_output) {
1540
+ #ifndef NDEBUG
1541
+ // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
1542
+ LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
1543
+ #endif
1544
+ buf_output = nullptr;
1545
+ logits = nullptr;
1546
+ embd = nullptr;
1547
+ }
1548
+
1549
+ auto * buft = ggml_backend_cpu_buffer_type();
1550
+ // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
1551
+ auto * output_dev = model.dev_output();
1552
+ auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
1553
+ if (output_dev_host_buft) {
1554
+ buft = output_dev_host_buft;
1555
+ }
1556
+ buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
1557
+ if (buf_output == nullptr) {
1558
+ LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
1559
+ return 0;
1560
+ }
1561
+ }
1562
+
1563
+ float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
1564
+
1565
+ logits = has_logits ? output_base : nullptr;
1566
+ embd = has_embd ? output_base + logits_size : nullptr;
1567
+
1568
+ // set all ids as invalid (negative)
1569
+ std::fill(output_ids.begin(), output_ids.end(), -1);
1570
+
1571
+ ggml_backend_buffer_clear(buf_output.get(), 0);
1572
+
1573
+ this->n_outputs = 0;
1574
+ this->n_outputs_max = n_outputs_max;
1575
+
1576
+ return n_outputs_max;
1577
+ }
1578
+
1579
+ void llama_context::output_reorder() {
1580
+ auto & out_ids = sbatch.out_ids;
1581
+ if (!out_ids.empty()) {
1582
+ const uint32_t n_vocab = model.vocab.n_tokens();
1583
+ const uint32_t n_embd = model.hparams.n_embd;
1584
+
1585
+ GGML_ASSERT((size_t) n_outputs == out_ids.size());
1586
+
1587
+ // TODO: is there something more efficient which also minimizes swaps?
1588
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1589
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
1590
+ int32_t j_min = i;
1591
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
1592
+ if (out_ids[j] < out_ids[j_min]) {
1593
+ j_min = j;
1594
+ }
1595
+ }
1596
+ if (j_min == i) { continue; }
1597
+ std::swap(out_ids[i], out_ids[j_min]);
1598
+ if (logits_size > 0) {
1599
+ for (uint32_t k = 0; k < n_vocab; k++) {
1600
+ std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
1601
+ }
1602
+ }
1603
+ if (embd_size > 0) {
1604
+ for (uint32_t k = 0; k < n_embd; k++) {
1605
+ std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
1606
+ }
1607
+ }
1608
+ }
1609
+ std::fill(output_ids.begin(), output_ids.end(), -1);
1610
+ for (int32_t i = 0; i < n_outputs; ++i) {
1611
+ output_ids[out_ids[i]] = i;
1612
+ }
1613
+ out_ids.clear();
1614
+ }
1615
+ }
+
+//
+// graph
+//
+
+int32_t llama_context::graph_max_nodes() const {
+    return std::max<int32_t>(65536, 5*model.n_tensors());
+}
+
+ggml_cgraph * llama_context::graph_init() {
+    ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute_meta.size(),
+        /*.mem_buffer =*/ buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ctx_compute.reset(ggml_init(params));
+
+    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+}
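graph_init() creates a metadata-only ggml context (no_alloc = true) over a preallocated scratch buffer and returns a custom-sized graph. The standalone sketch below reproduces that pattern under stated assumptions: the node count and the sizing via ggml_tensor_overhead()/ggml_graph_overhead_custom() follow common ggml usage and are not values taken from this diff.

#include <cstdint>
#include <vector>
#include "ggml.h"

// Sketch of the graph_init() pattern: metadata-only context plus custom graph.
// The context must outlive the graph (llama_context keeps it in ctx_compute).
struct meta_graph {
    std::vector<uint8_t> buf; // scratch holding tensor/graph metadata only
    ggml_context *       ctx;
    ggml_cgraph  *       gf;
};

static meta_graph make_meta_graph(size_t max_nodes) {
    meta_graph mg;
    mg.buf.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));

    ggml_init_params params = {
        /*.mem_size   =*/ mg.buf.size(),
        /*.mem_buffer =*/ mg.buf.data(),
        /*.no_alloc   =*/ true, // tensor data is allocated later by a backend/scheduler
    };

    mg.ctx = ggml_init(params);
    mg.gf  = ggml_new_graph_custom(mg.ctx, max_nodes, /*grads =*/ false);
    return mg;
}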
1636
+
1637
+ llm_graph_result_ptr llama_context::graph_build(
1638
+ ggml_context * ctx,
1639
+ ggml_cgraph * gf,
1640
+ const llama_ubatch & ubatch,
1641
+ llm_graph_type gtype) {
1642
+ return model.build_graph(
1643
+ {
1644
+ /*.ctx =*/ ctx,
1645
+ /*.arch =*/ model.arch,
1646
+ /*.hparams =*/ model.hparams,
1647
+ /*.cparams =*/ cparams,
1648
+ /*.ubatch =*/ ubatch,
1649
+ /*.sched =*/ sched.get(),
1650
+ /*.backend_cpu =*/ backend_cpu,
1651
+ /*.cvec =*/ &cvec,
1652
+ /*.loras =*/ &loras,
1653
+ /*.memory =*/ kv_self.get(),
1654
+ /*.cross =*/ &cross,
1655
+ /*.n_outputs =*/ n_outputs,
1656
+ /*.cb =*/ graph_get_cb(),
1657
+ }, gf, gtype);
1658
+ }
1659
+
1660
+ ggml_status llama_context::graph_compute(
1661
+ ggml_cgraph * gf,
1662
+ bool batched) {
1663
+ int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
1664
+ ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
1665
+
1666
+ if (backend_cpu != nullptr) {
1667
+ auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
1668
+ auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
1669
+ set_threadpool_fn(backend_cpu, tp);
1670
+ }
1671
+
1672
+ // set the number of threads for all the backends
1673
+ for (const auto & set_n_threads_fn : set_n_threads_fns) {
1674
+ set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
1675
+ }
1676
+
1677
+ auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf);
1678
+ if (status != GGML_STATUS_SUCCESS) {
1679
+ LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
1680
+ }
1681
+
1682
+ // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
1683
+
1684
+ return status;
1685
+ }
1686
+
1687
+ llm_graph_cb llama_context::graph_get_cb() const {
1688
+ return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
1689
+ if (il >= 0) {
1690
+ ggml_format_name(cur, "%s-%d", name, il);
1691
+ } else {
1692
+ ggml_set_name(cur, name);
1693
+ }
1694
+
1695
+ if (!cparams.offload_kqv) {
1696
+ if (strcmp(name, "kqv_merged_cont") == 0) {
1697
+ // all nodes between the KV store and the attention output are run on the CPU
1698
+ ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
1699
+ }
1700
+ }
1701
+
1702
+ // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
1703
+ // FIXME: fix in ggml_backend_sched
1704
+ const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
1705
+ if (ubatch.n_tokens < 32 || full_offload) {
1706
+ if (il != -1 && strcmp(name, "norm") == 0) {
1707
+ const auto & dev_layer = model.dev_layer(il);
1708
+ for (const auto & backend : backends) {
1709
+ if (ggml_backend_get_device(backend.get()) == dev_layer) {
1710
+ if (ggml_backend_supports_op(backend.get(), cur)) {
1711
+ ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
1712
+ }
1713
+ }
1714
+ }
1715
+ }
1716
+ }
1717
+ };
1718
+ }
1719
+
1720
+ //
1721
+ // state save/load
1722
+ //
1723
+
1724
+ class llama_io_write_dummy : public llama_io_write_i {
1725
+ public:
1726
+ llama_io_write_dummy() = default;
1727
+
1728
+ void write(const void * /* src */, size_t size) override {
1729
+ size_written += size;
1730
+ }
1731
+
1732
+ void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
1733
+ size_written += size;
1734
+ }
1735
+
1736
+ size_t n_bytes() override {
1737
+ return size_written;
1738
+ }
1739
+
1740
+ private:
1741
+ size_t size_written = 0;
1742
+ };
1743
+
1744
+ class llama_io_write_buffer : public llama_io_write_i {
1745
+ public:
1746
+ llama_io_write_buffer(
1747
+ uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1748
+
1749
+ void write(const void * src, size_t size) override {
1750
+ if (size > buf_size) {
1751
+ throw std::runtime_error("unexpectedly reached end of buffer");
1752
+ }
1753
+ memcpy(ptr, src, size);
1754
+ ptr += size;
1755
+ size_written += size;
1756
+ buf_size -= size;
1757
+ }
1758
+
1759
+ void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
1760
+ if (size > buf_size) {
1761
+ throw std::runtime_error("unexpectedly reached end of buffer");
1762
+ }
1763
+ ggml_backend_tensor_get(tensor, ptr, offset, size);
1764
+ ptr += size;
1765
+ size_written += size;
1766
+ buf_size -= size;
1767
+ }
1768
+
1769
+ size_t n_bytes() override {
1770
+ return size_written;
1771
+ }
1772
+
1773
+ private:
1774
+ uint8_t * ptr;
1775
+ size_t buf_size = 0;
1776
+ size_t size_written = 0;
1777
+ };
1778
+
1779
+ class llama_io_read_buffer : public llama_io_read_i {
1780
+ public:
1781
+ llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1782
+
1783
+ const uint8_t * read(size_t size) override {
1784
+ const uint8_t * base_ptr = ptr;
1785
+ if (size > buf_size) {
1786
+ throw std::runtime_error("unexpectedly reached end of buffer");
1787
+ }
1788
+ ptr += size;
1789
+ size_read += size;
1790
+ buf_size -= size;
1791
+ return base_ptr;
1792
+ }
1793
+
1794
+ void read_to(void * dst, size_t size) override {
1795
+ memcpy(dst, read(size), size);
1796
+ }
1797
+
1798
+ size_t n_bytes() override {
1799
+ return size_read;
1800
+ }
1801
+
1802
+ private:
1803
+ const uint8_t * ptr;
1804
+ size_t buf_size = 0;
1805
+ size_t size_read = 0;
1806
+ };
1807
+
1808
+ class llama_io_write_file : public llama_io_write_i {
1809
+ public:
1810
+ llama_io_write_file(llama_file * f) : file(f) {}
1811
+
1812
+ void write(const void * src, size_t size) override {
1813
+ file->write_raw(src, size);
1814
+ size_written += size;
1815
+ }
1816
+
1817
+ void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
1464
1818
  temp_buffer.resize(size);
1465
1819
  ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
1466
1820
  write(temp_buffer.data(), temp_buffer.size());
1467
1821
  }
1468
1822
 
1469
- size_t get_size_written() override {
1823
+ size_t n_bytes() override {
1470
1824
  return size_written;
1471
1825
  }
1472
- };
1473
1826
 
1474
- struct llama_data_read_file : llama_data_read {
1827
+ private:
1475
1828
  llama_file * file;
1476
- size_t size_read = 0;
1829
+ size_t size_written = 0;
1477
1830
  std::vector<uint8_t> temp_buffer;
1831
+ };
1478
1832
 
1479
- llama_data_read_file(llama_file * f) : file(f) {}
1833
+ class llama_io_read_file : public llama_io_read_i {
1834
+ public:
1835
+ llama_io_read_file(llama_file * f) : file(f) {}
1480
1836
 
1481
1837
  void read_to(void * dst, size_t size) override {
1482
1838
  file->read_raw(dst, size);
@@ -1489,89 +1845,78 @@ struct llama_data_read_file : llama_data_read {
1489
1845
  return temp_buffer.data();
1490
1846
  }
1491
1847
 
1492
- size_t get_size_read() override {
1848
+ size_t n_bytes() override {
1493
1849
  return size_read;
1494
1850
  }
1495
- };
1496
1851
 
1497
- /** copy state data into either a buffer or file depending on the passed in context
1498
- *
1499
- * file context:
1500
- * llama_file file("/path", "wb");
1501
- * llama_data_write_file data_ctx(&file);
1502
- * llama_state_get_data_internal(ctx, data_ctx);
1503
- *
1504
- * buffer context:
1505
- * std::vector<uint8_t> buf(max_size, 0);
1506
- * llama_data_write_buffer data_ctx(buf.data(), max_size);
1507
- * llama_state_get_data_internal(ctx, data_ctx);
1508
- *
1509
- */
1510
- static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
1511
- llama_synchronize(ctx);
1512
-
1513
- data_ctx.write_model_info(ctx);
1514
-
1515
- // copy outputs
1516
- data_ctx.write_output_ids(ctx);
1517
- data_ctx.write_logits(ctx);
1518
- data_ctx.write_embeddings(ctx);
1519
-
1520
- data_ctx.write_kv_cache(ctx);
1852
+ private:
1853
+ llama_file * file;
1854
+ size_t size_read = 0;
1855
+ std::vector<uint8_t> temp_buffer;
1856
+ };
1521
1857
 
1522
- return data_ctx.get_size_written();
1858
+ size_t llama_context::state_get_size() {
1859
+ llama_io_write_dummy io;
1860
+ try {
1861
+ return state_write_data(io);
1862
+ } catch (const std::exception & err) {
1863
+ LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
1864
+ return 0;
1865
+ }
1523
1866
  }
1524
1867
 
1525
- size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
1526
- llama_data_write_buffer data_ctx(dst, size);
1868
+ size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
1869
+ llama_io_write_buffer io(dst, size);
1527
1870
  try {
1528
- return llama_state_get_data_internal(ctx, data_ctx);
1871
+ return state_write_data(io);
1529
1872
  } catch (const std::exception & err) {
1530
1873
  LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
1531
1874
  return 0;
1532
1875
  }
1533
1876
  }
1534
1877
 
1535
- // Returns the *actual* size of the state.
1536
- // Intended to be used when saving to state to a buffer.
1537
- size_t llama_state_get_size(struct llama_context * ctx) {
1538
- llama_data_write_dummy data_ctx;
1878
+ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
1879
+ llama_io_read_buffer io(src, size);
1539
1880
  try {
1540
- return llama_state_get_data_internal(ctx, data_ctx);
1881
+ return state_read_data(io);
1541
1882
  } catch (const std::exception & err) {
1542
- LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
1883
+ LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
1543
1884
  return 0;
1544
1885
  }
1545
1886
  }
1546
1887
 
1547
- static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
1548
- llama_synchronize(ctx);
1549
-
1550
- data_ctx.read_model_info(ctx);
1551
-
1552
- // set outputs
1553
- data_ctx.read_output_ids(ctx);
1554
- data_ctx.read_logits(ctx);
1555
- data_ctx.read_embeddings(ctx);
1556
-
1557
- data_ctx.read_kv_cache(ctx);
1888
+ size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
1889
+ llama_io_write_dummy io;
1890
+ try {
1891
+ return state_seq_write_data(io, seq_id);
1892
+ } catch (const std::exception & err) {
1893
+ LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
1894
+ return 0;
1895
+ }
1896
+ }
1558
1897
 
1559
- return data_ctx.get_size_read();
1898
+ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
1899
+ llama_io_write_buffer io(dst, size);
1900
+ try {
1901
+ return state_seq_write_data(io, seq_id);
1902
+ } catch (const std::exception & err) {
1903
+ LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
1904
+ return 0;
1905
+ }
1560
1906
  }
1561
1907
 
1562
- // Sets the state reading from the specified source address
1563
- size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
1564
- llama_data_read_buffer data_ctx(src, size);
1908
+ size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
1909
+ llama_io_read_buffer io(src, size);
1565
1910
  try {
1566
- return llama_state_set_data_internal(ctx, data_ctx);
1911
+ return state_seq_read_data(io, seq_id);
1567
1912
  } catch (const std::exception & err) {
1568
1913
  LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
1569
1914
  return 0;
1570
1915
  }
1571
1916
  }
1572
1917
 
1573
- static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1574
- llama_file file(path_session, "rb");
1918
+ bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1919
+ llama_file file(filepath, "rb");
1575
1920
 
1576
1921
  // sanity checks
1577
1922
  {
@@ -1601,28 +1946,20 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
1601
1946
  {
1602
1947
  const size_t n_state_size_cur = file.size() - file.tell();
1603
1948
 
1604
- llama_data_read_file data_ctx(&file);
1605
- const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
1949
+ llama_io_read_file io( &file);
1950
+ const size_t n_read = state_read_data(io);
1606
1951
 
1607
1952
  if (n_read != n_state_size_cur) {
1608
1953
  LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
1609
1954
  return false;
1610
1955
  }
1611
1956
  }
1612
- return true;
1613
- }
1614
1957
 
1615
- bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1616
- try {
1617
- return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
1618
- } catch (const std::exception & err) {
1619
- LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
1620
- return false;
1621
- }
1958
+ return true;
1622
1959
  }
1623
1960
 
1624
- static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
1625
- llama_file file(path_session, "wb");
1961
+ bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
1962
+ llama_file file(filepath, "wb");
1626
1963
 
1627
1964
  file.write_u32(LLAMA_SESSION_MAGIC);
1628
1965
  file.write_u32(LLAMA_SESSION_VERSION);
@@ -1632,63 +1969,56 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha
1632
1969
  file.write_raw(tokens, sizeof(llama_token) * n_token_count);
1633
1970
 
1634
1971
  // save the context state using stream saving
1635
- llama_data_write_file data_ctx(&file);
1636
- llama_state_get_data_internal(ctx, data_ctx);
1972
+ llama_io_write_file io(&file);
1973
+ state_write_data(io);
1637
1974
 
1638
1975
  return true;
1639
1976
  }
1640
1977
 
1641
- bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
1642
- try {
1643
- return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
1644
- } catch (const std::exception & err) {
1645
- LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
1646
- return false;
1647
- }
1648
- }
1978
+ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1979
+ llama_file file(filepath, "rb");
1649
1980
 
1650
- static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
1651
- llama_synchronize(ctx);
1981
+ // version checks
1982
+ {
1983
+ const uint32_t magic = file.read_u32();
1984
+ const uint32_t version = file.read_u32();
1652
1985
 
1653
- data_ctx.write_kv_cache(ctx, seq_id);
1986
+ if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
1987
+ LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
1988
+ return 0;
1989
+ }
1990
+ }
1654
1991
 
1655
- return data_ctx.get_size_written();
1656
- }
1992
+ // load the prompt
1993
+ {
1994
+ const uint32_t n_token_count = file.read_u32();
1657
1995
 
1658
- size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
1659
- llama_data_write_dummy data_ctx;
1660
- return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
1661
- }
1996
+ if (n_token_count > n_token_capacity) {
1997
+ LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
1998
+ return 0;
1999
+ }
1662
2000
 
1663
- size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
1664
- llama_data_write_buffer data_ctx(dst, size);
1665
- try {
1666
- return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
1667
- } catch (const std::exception & err) {
1668
- LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
1669
- return 0;
2001
+ file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
2002
+ *n_token_count_out = n_token_count;
1670
2003
  }
1671
- }
1672
2004
 
1673
- static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
1674
- llama_synchronize(ctx);
1675
-
1676
- data_ctx.read_kv_cache(ctx, dest_seq_id);
2005
+ // restore the context state
2006
+ {
2007
+ const size_t state_size = file.size() - file.tell();
2008
+ llama_io_read_file io(&file);
2009
+ const size_t nread = state_seq_read_data(io, seq_id);
2010
+ if (!nread) {
2011
+ LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
2012
+ return 0;
2013
+ }
2014
+ GGML_ASSERT(nread <= state_size);
2015
+ GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
2016
+ }
1677
2017
 
1678
- return data_ctx.get_size_read();
2018
+ return file.tell();
1679
2019
  }
1680
2020
 
1681
- size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
1682
- llama_data_read_buffer data_ctx(src, size);
1683
- try {
1684
- return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
1685
- } catch (const std::exception & err) {
1686
- LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
1687
- return 0;
1688
- }
1689
- }
1690
-
1691
- static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
2021
+ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
1692
2022
  llama_file file(filepath, "wb");
1693
2023
 
1694
2024
  file.write_u32(LLAMA_STATE_SEQ_MAGIC);
@@ -1699,77 +2029,778 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
1699
2029
  file.write_raw(tokens, sizeof(llama_token) * n_token_count);
1700
2030
 
1701
2031
  // save the context state using stream saving
1702
- llama_data_write_file data_ctx(&file);
1703
- llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
2032
+ llama_io_write_file io(&file);
2033
+ state_seq_write_data(io, seq_id);
1704
2034
 
1705
2035
  const size_t res = file.tell();
1706
- GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
2036
+ GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
2037
+
1707
2038
  return res;
1708
2039
  }
1709
2040
 
1710
- static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1711
- llama_file file(filepath, "rb");
2041
+ size_t llama_context::state_write_data(llama_io_write_i & io) {
2042
+ LLAMA_LOG_DEBUG("%s: writing state\n", __func__);
1712
2043
 
1713
- // version checks
2044
+ // write model info
1714
2045
  {
1715
- const uint32_t magic = file.read_u32();
1716
- const uint32_t version = file.read_u32();
2046
+ LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__);
1717
2047
 
1718
- if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
1719
- LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
1720
- return 0;
2048
+ const std::string arch_str = llm_arch_name(model.arch);
2049
+ io.write_string(arch_str);
2050
+ // TODO: add more model-specific info which should prevent loading the session file if not identical
2051
+ }
2052
+
2053
+ // write output ids
2054
+ {
2055
+ LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
2056
+
2057
+ output_reorder();
2058
+
2059
+ const auto n_outputs = this->n_outputs;
2060
+ const auto & output_ids = this->output_ids;
2061
+
2062
+ std::vector<int32_t> w_output_pos;
2063
+
2064
+ GGML_ASSERT(n_outputs <= n_outputs_max);
2065
+
2066
+ w_output_pos.resize(n_outputs);
2067
+
2068
+ // build a more compact representation of the output ids
2069
+ for (size_t i = 0; i < n_batch(); ++i) {
2070
+ // map an output id to a position in the batch
2071
+ int32_t pos = output_ids[i];
2072
+ if (pos >= 0) {
2073
+ GGML_ASSERT(pos < n_outputs);
2074
+ w_output_pos[pos] = i;
2075
+ }
2076
+ }
2077
+
2078
+ io.write(&n_outputs, sizeof(n_outputs));
2079
+
2080
+ if (n_outputs) {
2081
+ io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
1721
2082
  }
1722
2083
  }
1723
2084
 
1724
- // load the prompt
2085
+ // write logits
1725
2086
  {
1726
- const uint32_t n_token_count = file.read_u32();
2087
+ LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
1727
2088
 
1728
- if (n_token_count > n_token_capacity) {
1729
- LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
1730
- return 0;
2089
+ const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
2090
+
2091
+ io.write(&logits_size, sizeof(logits_size));
2092
+
2093
+ if (logits_size) {
2094
+ io.write(logits, logits_size * sizeof(float));
1731
2095
  }
2096
+ }
1732
2097
 
1733
- file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
1734
- *n_token_count_out = n_token_count;
2098
+ // write embeddings
2099
+ {
2100
+ LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
2101
+
2102
+ const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
2103
+
2104
+ io.write(&embd_size, sizeof(embd_size));
2105
+
2106
+ if (embd_size) {
2107
+ io.write(embd, embd_size * sizeof(float));
2108
+ }
1735
2109
  }
1736
2110
 
1737
- // restore the context state
2111
+ LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
2112
+ kv_self->state_write(io);
2113
+
2114
+ return io.n_bytes();
2115
+ }
2116
+
2117
+ size_t llama_context::state_read_data(llama_io_read_i & io) {
2118
+ LLAMA_LOG_DEBUG("%s: reading state\n", __func__);
2119
+
2120
+ // read model info
1738
2121
  {
1739
- const size_t state_size = file.size() - file.tell();
1740
- llama_data_read_file data_ctx(&file);
1741
- const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
1742
- if (!nread) {
1743
- LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
1744
- return 0;
2122
+ LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__);
2123
+
2124
+ const std::string cur_arch_str = llm_arch_name(model.arch);
2125
+
2126
+ std::string arch_str;
2127
+ io.read_string(arch_str);
2128
+ if (cur_arch_str != arch_str) {
2129
+ throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
1745
2130
  }
1746
- GGML_ASSERT(nread <= state_size);
1747
- GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
2131
+ // TODO: add more info which needs to be identical but which is not verified otherwise
1748
2132
  }
1749
2133
 
1750
- return file.tell();
2134
+ // read output ids
2135
+ {
2136
+ LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
2137
+
2138
+ auto n_outputs = this->n_outputs;
2139
+ io.read_to(&n_outputs, sizeof(n_outputs));
2140
+
2141
+ if (n_outputs > output_reserve(n_outputs)) {
2142
+ throw std::runtime_error("could not reserve outputs");
2143
+ }
2144
+
2145
+ std::vector<int32_t> output_pos;
2146
+
2147
+ if (n_outputs) {
2148
+ output_pos.resize(n_outputs);
2149
+ io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
2150
+
2151
+ for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
2152
+ int32_t id = output_pos[i];
2153
+ if ((uint32_t) id >= n_batch()) {
2154
+ throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
2155
+ }
2156
+ this->output_ids[id] = i;
2157
+ }
2158
+
2159
+ this->n_outputs = n_outputs;
2160
+ }
2161
+ }
2162
+
2163
+ // read logits
2164
+ {
2165
+ LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
2166
+
2167
+ uint64_t logits_size;
2168
+ io.read_to(&logits_size, sizeof(logits_size));
2169
+
2170
+ if (this->logits_size < logits_size) {
2171
+ throw std::runtime_error("logits buffer too small");
2172
+ }
2173
+
2174
+ if (logits_size) {
2175
+ io.read_to(this->logits, logits_size * sizeof(float));
2176
+ }
2177
+ }
2178
+
2179
+ // read embeddings
2180
+ {
2181
+ LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
2182
+
2183
+ uint64_t embd_size;
2184
+ io.read_to(&embd_size, sizeof(embd_size));
2185
+
2186
+ if (this->embd_size < embd_size) {
2187
+ throw std::runtime_error("embeddings buffer too small");
2188
+ }
2189
+
2190
+ if (embd_size) {
2191
+ io.read_to(this->embd, embd_size * sizeof(float));
2192
+ }
2193
+ }
2194
+
2195
+ LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
2196
+ kv_self->state_read(io);
2197
+
2198
+ return io.n_bytes();
1751
2199
  }
1752
2200
 
1753
- size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
2201
+ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
2202
+ GGML_UNUSED(seq_id);
2203
+
2204
+ kv_self->state_write(io, seq_id);
2205
+
2206
+ return io.n_bytes();
2207
+ }
2208
+
2209
+ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
2210
+ GGML_UNUSED(seq_id);
2211
+
2212
+ kv_self->state_read(io, seq_id);
2213
+
2214
+ return io.n_bytes();
2215
+ }
2216
+
2217
+ //
2218
+ // perf
2219
+ //
2220
+
2221
+ llama_perf_context_data llama_context::perf_get_data() const {
2222
+ llama_perf_context_data data = {};
2223
+
2224
+ data.t_start_ms = 1e-3 * t_start_us;
2225
+ data.t_load_ms = 1e-3 * t_load_us;
2226
+ data.t_p_eval_ms = 1e-3 * t_p_eval_us;
2227
+ data.t_eval_ms = 1e-3 * t_eval_us;
2228
+ data.n_p_eval = std::max(1, n_p_eval);
2229
+ data.n_eval = std::max(1, n_eval);
2230
+
2231
+ return data;
2232
+ }
2233
+
2234
+ void llama_context::perf_reset() {
2235
+ t_start_us = ggml_time_us();
2236
+ t_eval_us = n_eval = 0;
2237
+ t_p_eval_us = n_p_eval = 0;
2238
+ }
2239
+
2240
+ //
2241
+ // interface implementation
2242
+ //
2243
+
2244
+ llama_context_params llama_context_default_params() {
2245
+ llama_context_params result = {
2246
+ /*.n_ctx =*/ 512,
2247
+ /*.n_batch =*/ 2048,
2248
+ /*.n_ubatch =*/ 512,
2249
+ /*.n_seq_max =*/ 1,
2250
+ /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
2251
+ /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
2252
+ /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
2253
+ /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
2254
+ /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
2255
+ /*.rope_freq_base =*/ 0.0f,
2256
+ /*.rope_freq_scale =*/ 0.0f,
2257
+ /*.yarn_ext_factor =*/ -1.0f,
2258
+ /*.yarn_attn_factor =*/ 1.0f,
2259
+ /*.yarn_beta_fast =*/ 32.0f,
2260
+ /*.yarn_beta_slow =*/ 1.0f,
2261
+ /*.yarn_orig_ctx =*/ 0,
2262
+ /*.defrag_thold =*/ -1.0f,
2263
+ /*.cb_eval =*/ nullptr,
2264
+ /*.cb_eval_user_data =*/ nullptr,
2265
+ /*.type_k =*/ GGML_TYPE_F16,
2266
+ /*.type_v =*/ GGML_TYPE_F16,
2267
+ /*.logits_all =*/ false,
2268
+ /*.embeddings =*/ false,
2269
+ /*.offload_kqv =*/ true,
2270
+ /*.flash_attn =*/ false,
2271
+ /*.no_perf =*/ true,
2272
+ /*.abort_callback =*/ nullptr,
2273
+ /*.abort_callback_data =*/ nullptr,
2274
+ };
2275
+
2276
+ return result;
2277
+ }
2278
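
The defaults above are meant to be taken as a starting point and overridden selectively before creating a context. A minimal sketch using only fields listed in this diff (the chosen values are illustrative):

```cpp
// Sketch: start from llama_context_default_params() and override a few fields.
#include <cstdio>
#include "llama.h"

int main() {
    llama_context_params cparams = llama_context_default_params();

    cparams.n_ctx           = 4096;   // context size (0 falls back to the model's training context)
    cparams.n_batch         = 2048;   // max tokens submitted per llama_decode call
    cparams.n_ubatch        = 512;    // physical micro-batch size
    cparams.n_threads       = 8;      // generation threads
    cparams.n_threads_batch = 8;      // prompt-processing threads
    cparams.flash_attn      = true;   // may be forced off by the checks in llama_init_from_model below
    cparams.no_perf         = false;  // collect timings so llama_perf_context_print reports numbers

    printf("n_ctx = %u, n_batch = %u\n", cparams.n_ctx, cparams.n_batch);
    return 0;
}
```
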
+
2279
+ llama_context * llama_init_from_model(
2280
+ llama_model * model,
2281
+ llama_context_params params) {
2282
+ if (!model) {
2283
+ LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
2284
+ return nullptr;
2285
+ }
2286
+
2287
+ if (params.n_batch == 0 && params.n_ubatch == 0) {
2288
+ LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
2289
+ return nullptr;
2290
+ }
2291
+
2292
+ if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
2293
+ LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
2294
+ return nullptr;
2295
+ }
2296
+
2297
+ if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
2298
+ LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
2299
+ params.flash_attn = false;
2300
+ }
2301
+
2302
+ if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
2303
+ LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
2304
+ params.flash_attn = false;
2305
+ }
2306
+
2307
+ if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
2308
+ LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
2309
+ return nullptr;
2310
+ }
2311
+
1754
2312
  try {
1755
- return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
2313
+ auto * ctx = new llama_context(*model, params);
2314
+ return ctx;
2315
+ } catch (const std::exception & err) {
2316
+ LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what());
2317
+ }
2318
+
2319
+ return nullptr;
2320
+ }
2321
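
llama_init_from_model above validates the parameters and constructs the context. A typical call site looks like the following sketch; llama_model_default_params, llama_model_load_from_file and llama_model_free are assumed from the rest of llama.h rather than this diff, and the model path is a placeholder:

```cpp
// Sketch: create and destroy a context with llama_init_from_model.
// The model-loading helpers are assumed from the wider llama.h API.
#include <cstdio>
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (!model) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 4096;

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        fprintf(stderr, "failed to create context\n");
        llama_model_free(model);
        return 1;
    }

    printf("n_ctx = %u, n_batch = %u, n_seq_max = %u\n",
           llama_n_ctx(ctx), llama_n_batch(ctx), llama_n_seq_max(ctx));

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}
```
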
+
2322
+ // deprecated
2323
+ llama_context * llama_new_context_with_model(
2324
+ llama_model * model,
2325
+ llama_context_params params) {
2326
+ return llama_init_from_model(model, params);
2327
+ }
2328
+
2329
+ void llama_free(llama_context * ctx) {
2330
+ delete ctx;
2331
+ }
2332
+
2333
+ uint32_t llama_n_ctx(const llama_context * ctx) {
2334
+ return ctx->n_ctx();
2335
+ }
2336
+
2337
+ uint32_t llama_n_batch(const llama_context * ctx) {
2338
+ return ctx->n_batch();
2339
+ }
2340
+
2341
+ uint32_t llama_n_ubatch(const llama_context * ctx) {
2342
+ return ctx->n_ubatch();
2343
+ }
2344
+
2345
+ uint32_t llama_n_seq_max(const llama_context * ctx) {
2346
+ return ctx->n_seq_max();
2347
+ }
2348
+
2349
+ const llama_model * llama_get_model(const llama_context * ctx) {
2350
+ return &ctx->get_model();
2351
+ }
2352
+
2353
+ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
2354
+ return ctx->get_kv_self();
2355
+ }
2356
+
2357
+ void llama_kv_self_update(llama_context * ctx) {
2358
+ ctx->kv_self_update();
2359
+ }
2360
+
2361
+ enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
2362
+ return ctx->pooling_type();
2363
+ }
2364
+
2365
+ void llama_attach_threadpool(
2366
+ llama_context * ctx,
2367
+ ggml_threadpool_t threadpool,
2368
+ ggml_threadpool_t threadpool_batch) {
2369
+ ctx->attach_threadpool(threadpool, threadpool_batch);
2370
+ }
2371
+
2372
+ void llama_detach_threadpool(llama_context * ctx) {
2373
+ ctx->detach_threadpool();
2374
+ }
2375
+
2376
+ void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
2377
+ ctx->set_n_threads(n_threads, n_threads_batch);
2378
+ }
2379
+
2380
+ int32_t llama_n_threads(llama_context * ctx) {
2381
+ return ctx->n_threads();
2382
+ }
2383
+
2384
+ int32_t llama_n_threads_batch(llama_context * ctx) {
2385
+ return ctx->n_threads_batch();
2386
+ }
2387
+
2388
+ void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
2389
+ ctx->set_abort_callback(abort_callback, abort_callback_data);
2390
+ }
2391
+
2392
+ void llama_set_embeddings(llama_context * ctx, bool embeddings) {
2393
+ ctx->set_embeddings(embeddings);
2394
+ }
2395
+
2396
+ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
2397
+ ctx->set_causal_attn(causal_attn);
2398
+ }
2399
+
2400
+ void llama_set_warmup(llama_context * ctx, bool warmup) {
2401
+ ctx->set_warmup(warmup);
2402
+ }
2403
+
2404
+ void llama_synchronize(llama_context * ctx) {
2405
+ ctx->synchronize();
2406
+ }
2407
+
2408
+ float * llama_get_logits(llama_context * ctx) {
2409
+ ctx->synchronize();
2410
+
2411
+ return ctx->get_logits();
2412
+ }
2413
+
2414
+ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
2415
+ ctx->synchronize();
2416
+
2417
+ return ctx->get_logits_ith(i);
2418
+ }
2419
+
2420
+ float * llama_get_embeddings(llama_context * ctx) {
2421
+ ctx->synchronize();
2422
+
2423
+ return ctx->get_embeddings();
2424
+ }
2425
+
2426
+ float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) {
2427
+ ctx->synchronize();
2428
+
2429
+ return ctx->get_embeddings_ith(i);
2430
+ }
2431
+
2432
+ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
2433
+ ctx->synchronize();
2434
+
2435
+ return ctx->get_embeddings_seq(seq_id);
2436
+ }
2437
+
2438
+ // llama adapter API
2439
+
2440
+ int32_t llama_set_adapter_lora(
2441
+ llama_context * ctx,
2442
+ llama_adapter_lora * adapter,
2443
+ float scale) {
2444
+ ctx->set_adapter_lora(adapter, scale);
2445
+
2446
+ return 0;
2447
+ }
2448
+
2449
+ int32_t llama_rm_adapter_lora(
2450
+ llama_context * ctx,
2451
+ llama_adapter_lora * adapter) {
2452
+ bool res = ctx->rm_adapter_lora(adapter);
2453
+
2454
+ return res ? 0 : -1;
2455
+ }
2456
+
2457
+ void llama_clear_adapter_lora(llama_context * ctx) {
2458
+ ctx->clear_adapter_lora();
2459
+ }
2460
+
2461
+ int32_t llama_apply_adapter_cvec(
2462
+ llama_context * ctx,
2463
+ const float * data,
2464
+ size_t len,
2465
+ int32_t n_embd,
2466
+ int32_t il_start,
2467
+ int32_t il_end) {
2468
+ bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
2469
+
2470
+ return res ? 0 : -1;
2471
+ }
2472
+
2473
+ //
2474
+ // kv cache view
2475
+ //
2476
+
2477
+ llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
2478
+ const auto * kv = ctx->get_kv_self();
2479
+ if (kv == nullptr) {
2480
+ LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
2481
+ return {};
2482
+ }
2483
+
2484
+ return llama_kv_cache_view_init(*kv, n_seq_max);
2485
+ }
2486
+
2487
+ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
2488
+ const auto * kv = ctx->get_kv_self();
2489
+ if (kv == nullptr) {
2490
+ LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
2491
+ return;
2492
+ }
2493
+
2494
+ llama_kv_cache_view_update(view, kv);
2495
+ }
2496
+
2497
+ //
2498
+ // kv cache
2499
+ //
2500
+
2501
+ // deprecated
2502
+ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
2503
+ return llama_kv_self_n_tokens(ctx);
2504
+ }
2505
+
2506
+ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2507
+ return llama_kv_cache_n_tokens(ctx->get_kv_self());
2508
+ }
2509
+
2510
+ // deprecated
2511
+ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
2512
+ return llama_kv_self_used_cells(ctx);
2513
+ }
2514
+
2515
+ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2516
+ return llama_kv_cache_used_cells(ctx->get_kv_self());
2517
+ }
2518
+
2519
+ // deprecated
2520
+ void llama_kv_cache_clear(llama_context * ctx) {
2521
+ llama_kv_self_clear(ctx);
2522
+ }
2523
+
2524
+ void llama_kv_self_clear(llama_context * ctx) {
2525
+ llama_kv_cache_clear(ctx->get_kv_self());
2526
+ }
2527
+
2528
+ // deprecated
2529
+ bool llama_kv_cache_seq_rm(
2530
+ llama_context * ctx,
2531
+ llama_seq_id seq_id,
2532
+ llama_pos p0,
2533
+ llama_pos p1) {
2534
+ return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
2535
+ }
2536
+
2537
+ bool llama_kv_self_seq_rm(
2538
+ llama_context * ctx,
2539
+ llama_seq_id seq_id,
2540
+ llama_pos p0,
2541
+ llama_pos p1) {
2542
+ return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1);
2543
+ }
2544
+
2545
+ // deprecated
2546
+ void llama_kv_cache_seq_cp(
2547
+ llama_context * ctx,
2548
+ llama_seq_id seq_id_src,
2549
+ llama_seq_id seq_id_dst,
2550
+ llama_pos p0,
2551
+ llama_pos p1) {
2552
+ return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
2553
+ }
2554
+
2555
+ void llama_kv_self_seq_cp(
2556
+ llama_context * ctx,
2557
+ llama_seq_id seq_id_src,
2558
+ llama_seq_id seq_id_dst,
2559
+ llama_pos p0,
2560
+ llama_pos p1) {
2561
+ return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1);
2562
+ }
2563
+
2564
+ // deprecated
2565
+ void llama_kv_cache_seq_keep(
2566
+ llama_context * ctx,
2567
+ llama_seq_id seq_id) {
2568
+ return llama_kv_self_seq_keep(ctx, seq_id);
2569
+ }
2570
+
2571
+ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2572
+ return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id);
2573
+ }
2574
+
2575
+ // deprecated
2576
+ void llama_kv_cache_seq_add(
2577
+ llama_context * ctx,
2578
+ llama_seq_id seq_id,
2579
+ llama_pos p0,
2580
+ llama_pos p1,
2581
+ llama_pos delta) {
2582
+ return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
2583
+ }
2584
+
2585
+ void llama_kv_self_seq_add(
2586
+ llama_context * ctx,
2587
+ llama_seq_id seq_id,
2588
+ llama_pos p0,
2589
+ llama_pos p1,
2590
+ llama_pos delta) {
2591
+ return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta);
2592
+ }
2593
+
2594
+ // deprecated
2595
+ void llama_kv_cache_seq_div(
2596
+ llama_context * ctx,
2597
+ llama_seq_id seq_id,
2598
+ llama_pos p0,
2599
+ llama_pos p1,
2600
+ int d) {
2601
+ return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
2602
+ }
2603
+
2604
+ void llama_kv_self_seq_div(
2605
+ llama_context * ctx,
2606
+ llama_seq_id seq_id,
2607
+ llama_pos p0,
2608
+ llama_pos p1,
2609
+ int d) {
2610
+ return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d);
2611
+ }
2612
+
2613
+ // deprecated
2614
+ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2615
+ return llama_kv_self_seq_pos_max(ctx, seq_id);
2616
+ }
2617
+
2618
+ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2619
+ return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id);
2620
+ }
2621
+
2622
+ // deprecated
2623
+ void llama_kv_cache_defrag(llama_context * ctx) {
2624
+ return llama_kv_self_defrag(ctx);
2625
+ }
2626
+
2627
+ void llama_kv_self_defrag(llama_context * ctx) {
2628
+ llama_kv_cache_defrag(ctx->get_kv_self());
2629
+ }
2630
+
2631
+ // deprecated
2632
+ bool llama_kv_cache_can_shift(const llama_context * ctx) {
2633
+ return llama_kv_self_can_shift(ctx);
2634
+ }
2635
+
2636
+ bool llama_kv_self_can_shift(const llama_context * ctx) {
2637
+ return llama_kv_cache_can_shift(ctx->get_kv_self());
2638
+ }
2639
+
2640
+ // deprecated
2641
+ void llama_kv_cache_update(llama_context * ctx) {
2642
+ llama_kv_self_update(ctx);
2643
+ }
2644
+
2645
+ // llama state API
2646
+
2647
+ // deprecated
2648
+ size_t llama_get_state_size(llama_context * ctx) {
2649
+ return llama_state_get_size(ctx);
2650
+ }
2651
+
2652
+ // deprecated
2653
+ size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) {
2654
+ return llama_state_get_data(ctx, dst, -1);
2655
+ }
2656
+
2657
+ // deprecated
2658
+ size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) {
2659
+ return llama_state_set_data(ctx, src, -1);
2660
+ }
2661
+
2662
+ // deprecated
2663
+ bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2664
+ return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
2665
+ }
2666
+
2667
+ // deprecated
2668
+ bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
2669
+ return llama_state_save_file(ctx, path_session, tokens, n_token_count);
2670
+ }
2671
+
2672
+ // Returns the *actual* size of the state.
2673
+ // Intended to be used when saving to state to a buffer.
+ // Intended to be used when saving the state to a buffer.
2674
+ size_t llama_state_get_size(llama_context * ctx) {
2675
+ return ctx->state_get_size();
2676
+ }
2677
+
2678
+ size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) {
2679
+ ctx->synchronize();
2680
+
2681
+ return ctx->state_get_data(dst, size);
2682
+ }
2683
+
2684
+ // Sets the state reading from the specified source address
2685
+ size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) {
2686
+ ctx->synchronize();
2687
+
2688
+ return ctx->state_set_data(src, size);
2689
+ }
2690
+
2691
+ bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2692
+ ctx->synchronize();
2693
+
2694
+ try {
2695
+ return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
2696
+ } catch (const std::exception & err) {
2697
+ LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
2698
+ return false;
2699
+ }
2700
+ }
2701
+
2702
+ bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
2703
+ ctx->synchronize();
2704
+
2705
+ try {
2706
+ return ctx->state_save_file(path_session, tokens, n_token_count);
2707
+ } catch (const std::exception & err) {
2708
+ LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
2709
+ return false;
2710
+ }
2711
+ }
2712
+
2713
+ size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
2714
+ return ctx->state_seq_get_size(seq_id);
2715
+ }
2716
+
2717
+ size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
2718
+ ctx->synchronize();
2719
+
2720
+ return ctx->state_seq_get_data(seq_id, dst, size);
2721
+ }
2722
+
2723
+ size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
2724
+ ctx->synchronize();
2725
+
2726
+ return ctx->state_seq_set_data(seq_id, src, size);
2727
+ }
2728
+
2729
+ size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
2730
+ ctx->synchronize();
2731
+
2732
+ try {
2733
+ return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
1756
2734
  } catch (const std::exception & err) {
1757
2735
  LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
1758
2736
  return 0;
1759
2737
  }
1760
2738
  }
1761
2739
 
1762
- size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2740
+ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2741
+ ctx->synchronize();
2742
+
1763
2743
  try {
1764
- return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
2744
+ return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
1765
2745
  } catch (const std::exception & err) {
1766
2746
  LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
1767
2747
  return 0;
1768
2748
  }
1769
2749
  }
1770
2750
 
1771
- const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
1772
- struct llama_context * ctx
1773
- ) {
1774
- return ctx->model.tensors_by_name;
2751
+ ///
2752
+
2753
+ int32_t llama_encode(
2754
+ llama_context * ctx,
2755
+ llama_batch batch) {
2756
+ const int ret = ctx->encode(batch);
2757
+ if (ret != 0) {
2758
+ LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
2759
+ }
2760
+
2761
+ return ret;
2762
+ }
2763
+
2764
+ int32_t llama_decode(
2765
+ llama_context * ctx,
2766
+ llama_batch batch) {
2767
+ const int ret = ctx->decode(batch);
2768
+ if (ret != 0) {
2769
+ LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2770
+ }
2771
+
2772
+ return ret;
2773
+ }
2774
+
2775
+ //
2776
+ // perf
2777
+ //
2778
+
2779
+ llama_perf_context_data llama_perf_context(const llama_context * ctx) {
2780
+ llama_perf_context_data data = {};
2781
+
2782
+ if (ctx == nullptr) {
2783
+ return data;
2784
+ }
2785
+
2786
+ data = ctx->perf_get_data();
2787
+
2788
+ return data;
2789
+ }
2790
+
2791
+ void llama_perf_context_print(const llama_context * ctx) {
2792
+ const auto data = llama_perf_context(ctx);
2793
+
2794
+ const double t_end_ms = 1e-3 * ggml_time_us();
2795
+
2796
+ LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
2797
+ LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
2798
+ __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
2799
+ LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2800
+ __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
2801
+ LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
2802
+ }
2803
+
2804
+ void llama_perf_context_reset(llama_context * ctx) {
2805
+ ctx->perf_reset();
1775
2806
  }
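
A sketch of the intended use of these perf helpers around a generation run; note the context must be created with no_perf = false for timings to be collected:

```cpp
// Sketch: reset perf counters before a run, print them afterwards.
#include "llama.h"

static void timed_run(llama_context * ctx) {
    llama_perf_context_reset(ctx);

    // ... llama_decode / sampling loop goes here ...

    llama_perf_context_data data = llama_perf_context(ctx);   // raw numbers, if needed
    (void) data.t_eval_ms;

    llama_perf_context_print(ctx);                            // formatted summary as above
}
```
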