llama-cpp-capacitor 0.0.5 → 0.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. package/cpp/LICENSE +21 -0
  2. package/cpp/README.md +4 -0
  3. package/cpp/anyascii.c +22223 -0
  4. package/cpp/anyascii.h +42 -0
  5. package/cpp/chat-parser.cpp +393 -0
  6. package/cpp/chat-parser.h +120 -0
  7. package/cpp/chat.cpp +2315 -0
  8. package/cpp/chat.h +221 -0
  9. package/cpp/common.cpp +1619 -0
  10. package/cpp/common.h +744 -0
  11. package/cpp/ggml-alloc.c +1028 -0
  12. package/cpp/ggml-alloc.h +76 -0
  13. package/cpp/ggml-backend-impl.h +255 -0
  14. package/cpp/ggml-backend-reg.cpp +600 -0
  15. package/cpp/ggml-backend.cpp +2118 -0
  16. package/cpp/ggml-backend.h +354 -0
  17. package/cpp/ggml-common.h +1878 -0
  18. package/cpp/ggml-cpp.h +39 -0
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2512 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  25. package/cpp/ggml-cpu/arch/arm/quants.c +3650 -0
  26. package/cpp/ggml-cpu/arch/arm/repack.cpp +1891 -0
  27. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  28. package/cpp/ggml-cpu/arch/x86/quants.c +3820 -0
  29. package/cpp/ggml-cpu/arch/x86/repack.cpp +6307 -0
  30. package/cpp/ggml-cpu/arch-fallback.h +215 -0
  31. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  32. package/cpp/ggml-cpu/binary-ops.h +16 -0
  33. package/cpp/ggml-cpu/common.h +73 -0
  34. package/cpp/ggml-cpu/ggml-cpu-impl.h +525 -0
  35. package/cpp/ggml-cpu/ggml-cpu.c +3578 -0
  36. package/cpp/ggml-cpu/ggml-cpu.cpp +672 -0
  37. package/cpp/ggml-cpu/ops.cpp +10587 -0
  38. package/cpp/ggml-cpu/ops.h +114 -0
  39. package/cpp/ggml-cpu/quants.c +1193 -0
  40. package/cpp/ggml-cpu/quants.h +97 -0
  41. package/cpp/ggml-cpu/repack.cpp +1982 -0
  42. package/cpp/ggml-cpu/repack.h +120 -0
  43. package/cpp/ggml-cpu/simd-mappings.h +1184 -0
  44. package/cpp/ggml-cpu/traits.cpp +36 -0
  45. package/cpp/ggml-cpu/traits.h +38 -0
  46. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  47. package/cpp/ggml-cpu/unary-ops.h +28 -0
  48. package/cpp/ggml-cpu/vec.cpp +348 -0
  49. package/cpp/ggml-cpu/vec.h +1121 -0
  50. package/cpp/ggml-cpu.h +145 -0
  51. package/cpp/ggml-impl.h +622 -0
  52. package/cpp/ggml-metal-impl.h +688 -0
  53. package/cpp/ggml-metal.h +66 -0
  54. package/cpp/ggml-metal.m +6833 -0
  55. package/cpp/ggml-opt.cpp +1093 -0
  56. package/cpp/ggml-opt.h +256 -0
  57. package/cpp/ggml-quants.c +5324 -0
  58. package/cpp/ggml-quants.h +106 -0
  59. package/cpp/ggml-threading.cpp +12 -0
  60. package/cpp/ggml-threading.h +14 -0
  61. package/cpp/ggml.c +7108 -0
  62. package/cpp/ggml.h +2492 -0
  63. package/cpp/gguf.cpp +1358 -0
  64. package/cpp/gguf.h +202 -0
  65. package/cpp/json-partial.cpp +256 -0
  66. package/cpp/json-partial.h +38 -0
  67. package/cpp/json-schema-to-grammar.cpp +985 -0
  68. package/cpp/json-schema-to-grammar.h +21 -0
  69. package/cpp/llama-adapter.cpp +388 -0
  70. package/cpp/llama-adapter.h +76 -0
  71. package/cpp/llama-arch.cpp +2355 -0
  72. package/cpp/llama-arch.h +499 -0
  73. package/cpp/llama-batch.cpp +875 -0
  74. package/cpp/llama-batch.h +160 -0
  75. package/cpp/llama-chat.cpp +783 -0
  76. package/cpp/llama-chat.h +65 -0
  77. package/cpp/llama-context.cpp +2748 -0
  78. package/cpp/llama-context.h +306 -0
  79. package/cpp/llama-cparams.cpp +5 -0
  80. package/cpp/llama-cparams.h +41 -0
  81. package/cpp/llama-cpp.h +30 -0
  82. package/cpp/llama-grammar.cpp +1229 -0
  83. package/cpp/llama-grammar.h +173 -0
  84. package/cpp/llama-graph.cpp +1891 -0
  85. package/cpp/llama-graph.h +810 -0
  86. package/cpp/llama-hparams.cpp +180 -0
  87. package/cpp/llama-hparams.h +233 -0
  88. package/cpp/llama-impl.cpp +167 -0
  89. package/cpp/llama-impl.h +61 -0
  90. package/cpp/llama-io.cpp +15 -0
  91. package/cpp/llama-io.h +35 -0
  92. package/cpp/llama-kv-cache-iswa.cpp +318 -0
  93. package/cpp/llama-kv-cache-iswa.h +135 -0
  94. package/cpp/llama-kv-cache.cpp +2059 -0
  95. package/cpp/llama-kv-cache.h +374 -0
  96. package/cpp/llama-kv-cells.h +491 -0
  97. package/cpp/llama-memory-hybrid.cpp +258 -0
  98. package/cpp/llama-memory-hybrid.h +137 -0
  99. package/cpp/llama-memory-recurrent.cpp +1146 -0
  100. package/cpp/llama-memory-recurrent.h +179 -0
  101. package/cpp/llama-memory.cpp +59 -0
  102. package/cpp/llama-memory.h +119 -0
  103. package/cpp/llama-mmap.cpp +600 -0
  104. package/cpp/llama-mmap.h +68 -0
  105. package/cpp/llama-model-loader.cpp +1164 -0
  106. package/cpp/llama-model-loader.h +170 -0
  107. package/cpp/llama-model-saver.cpp +282 -0
  108. package/cpp/llama-model-saver.h +37 -0
  109. package/cpp/llama-model.cpp +19042 -0
  110. package/cpp/llama-model.h +491 -0
  111. package/cpp/llama-sampling.cpp +2575 -0
  112. package/cpp/llama-sampling.h +32 -0
  113. package/cpp/llama-vocab.cpp +3792 -0
  114. package/cpp/llama-vocab.h +176 -0
  115. package/cpp/llama.cpp +358 -0
  116. package/cpp/llama.h +1373 -0
  117. package/cpp/log.cpp +427 -0
  118. package/cpp/log.h +103 -0
  119. package/cpp/minja/chat-template.hpp +550 -0
  120. package/cpp/minja/minja.hpp +3009 -0
  121. package/cpp/nlohmann/json.hpp +25526 -0
  122. package/cpp/nlohmann/json_fwd.hpp +187 -0
  123. package/cpp/regex-partial.cpp +204 -0
  124. package/cpp/regex-partial.h +56 -0
  125. package/cpp/rn-completion.cpp +681 -0
  126. package/cpp/rn-completion.h +116 -0
  127. package/cpp/rn-llama.cpp +345 -0
  128. package/cpp/rn-llama.h +149 -0
  129. package/cpp/rn-mtmd.hpp +602 -0
  130. package/cpp/rn-tts.cpp +591 -0
  131. package/cpp/rn-tts.h +59 -0
  132. package/cpp/sampling.cpp +579 -0
  133. package/cpp/sampling.h +107 -0
  134. package/cpp/tools/mtmd/clip-impl.h +473 -0
  135. package/cpp/tools/mtmd/clip.cpp +4322 -0
  136. package/cpp/tools/mtmd/clip.h +106 -0
  137. package/cpp/tools/mtmd/miniaudio/miniaudio.h +93468 -0
  138. package/cpp/tools/mtmd/mtmd-audio.cpp +769 -0
  139. package/cpp/tools/mtmd/mtmd-audio.h +47 -0
  140. package/cpp/tools/mtmd/mtmd-helper.cpp +460 -0
  141. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  142. package/cpp/tools/mtmd/mtmd.cpp +1066 -0
  143. package/cpp/tools/mtmd/mtmd.h +298 -0
  144. package/cpp/tools/mtmd/stb/stb_image.h +7988 -0
  145. package/cpp/unicode-data.cpp +7034 -0
  146. package/cpp/unicode-data.h +20 -0
  147. package/cpp/unicode.cpp +1061 -0
  148. package/cpp/unicode.h +68 -0
  149. package/package.json +2 -1
package/cpp/llama-kv-cache-iswa.cpp
@@ -0,0 +1,318 @@
+ #include "llama-kv-cache-iswa.h"
+
+ #include "llama-impl.h"
+ #include "llama-batch.h"
+ #include "llama-model.h"
+
+ #include <algorithm>
+ #include <cassert>
+
+ //
+ // llama_kv_cache_iswa
+ //
+
+ llama_kv_cache_iswa::llama_kv_cache_iswa(
+         const llama_model & model,
+         lm_ggml_type type_k,
+         lm_ggml_type type_v,
+         bool v_trans,
+         bool offload,
+         bool swa_full,
+         bool unified,
+         uint32_t kv_size,
+         uint32_t n_seq_max,
+         uint32_t n_ubatch,
+         uint32_t n_pad,
+         const layer_filter_cb & filter,
+         const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
+
+     // chain filters
+     const layer_filter_cb filter_base = [&](int32_t il) {
+         if (filter && !filter(il)) {
+             return false;
+         }
+
+         return !model.hparams.is_swa(il);
+     };
+
+     const layer_filter_cb filter_swa = [&](int32_t il) {
+         if (filter && !filter(il)) {
+             return false;
+         }
+
+         return model.hparams.is_swa(il);
+     };
+
+     const uint32_t size_base = kv_size;
+
+     uint32_t size_swa = std::min(size_base, LM_GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+
+     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
+     if (swa_full) {
+         LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
+                 __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+         size_swa = size_base;
+     }
+
+     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
+
+     kv_base = std::make_unique<llama_kv_cache>(
+             model, type_k, type_v,
+             v_trans, offload, unified, size_base, n_seq_max, n_pad,
+             0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
+
+     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
+
+     kv_swa = std::make_unique<llama_kv_cache>(
+             model, type_k, type_v,
+             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
+             hparams.n_swa, hparams.swa_type, filter_swa, reuse);
+ }
+
+ void llama_kv_cache_iswa::clear(bool data) {
+     kv_base->clear(data);
+     kv_swa ->clear(data);
+ }
+
+ bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+     bool res = true;
+
+     res = res & kv_base->seq_rm(seq_id, p0, p1);
+     res = res & kv_swa ->seq_rm(seq_id, p0, p1);
+
+     return res;
+ }
+
+ void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+     kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+     kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ }
+
+ void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
+     kv_base->seq_keep(seq_id);
+     kv_swa ->seq_keep(seq_id);
+ }
+
+ void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+     kv_base->seq_add(seq_id, p0, p1, shift);
+     kv_swa ->seq_add(seq_id, p0, p1, shift);
+ }
+
+ void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+     kv_base->seq_div(seq_id, p0, p1, d);
+     kv_swa ->seq_div(seq_id, p0, p1, d);
+ }
+
+ llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
+     // the base cache is a superset of the SWA cache, so we can just check the SWA cache
+     return kv_swa->seq_pos_min(seq_id);
+ }
+
+ llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
+     return kv_swa->seq_pos_max(seq_id);
+ }
+
+ llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+     LM_GGML_UNUSED(embd_all);
+
+     // first try simple split
+     do {
+         if (!unified) {
+             // requires equal splits, so we skip the simple split
+             break;
+         }
+
+         balloc.split_reset();
+
+         std::vector<llama_ubatch> ubatches;
+         while (true) {
+             auto ubatch = balloc.split_simple(n_ubatch);
+
+             if (ubatch.n_tokens == 0) {
+                 break;
+             }
+
+             ubatches.push_back(std::move(ubatch)); // NOLINT
+         }
+
+         if (balloc.get_n_used() < balloc.get_n_tokens()) {
+             // failed to find a suitable split
+             break;
+         }
+
+         auto sinfos_base = kv_base->prepare(ubatches);
+         if (sinfos_base.empty()) {
+             break;
+         }
+
+         auto sinfos_swa = kv_swa->prepare(ubatches);
+         if (sinfos_swa.empty()) {
+             break;
+         }
+
+         assert(sinfos_base.size() == sinfos_swa.size());
+
+         return std::make_unique<llama_kv_cache_iswa_context>(
+                 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+     } while (false);
+
+     // if it fails, try equal split
+     do {
+         balloc.split_reset();
+
+         std::vector<llama_ubatch> ubatches;
+         while (true) {
+             auto ubatch = balloc.split_equal(n_ubatch, !unified);
+
+             if (ubatch.n_tokens == 0) {
+                 break;
+             }
+
+             ubatches.push_back(std::move(ubatch)); // NOLINT
+         }
+
+         if (balloc.get_n_used() < balloc.get_n_tokens()) {
+             // failed to find a suitable split
+             break;
+         }
+
+         auto sinfos_base = kv_base->prepare(ubatches);
+         if (sinfos_base.empty()) {
+             break;
+         }
+
+         auto sinfos_swa = kv_swa->prepare(ubatches);
+         if (sinfos_swa.empty()) {
+             break;
+         }
+
+         assert(sinfos_base.size() == sinfos_swa.size());
+
+         return std::make_unique<llama_kv_cache_iswa_context>(
+                 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+     } while (false);
+
+     // TODO: if we fail again, we should attempt different splitting strategies
+     // but to do that properly, we first have to refactor the batches to be more flexible
+
+     return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ }
+
+ llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
+     return std::make_unique<llama_kv_cache_iswa_context>(this);
+ }
+
+ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
+     return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
+ }
+
+ bool llama_kv_cache_iswa::get_can_shift() const {
+     return kv_base->get_size() == kv_swa->get_size();
+ }
+
+ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+     if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+         kv_base->state_write(io, seq_id, flags);
+     }
+
+     kv_swa->state_write(io, seq_id, flags);
+ }
+
+ void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+     if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+         kv_base->state_read(io, seq_id, flags);
+     }
+
+     kv_swa->state_read(io, seq_id, flags);
+ }
+
+ llama_kv_cache * llama_kv_cache_iswa::get_base() const {
+     return kv_base.get();
+ }
+
+ llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
+     return kv_swa.get();
+ }
+
+ //
+ // llama_kv_cache_iswa_context
+ //
+
+ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
+
+ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+         llama_kv_cache_iswa * kv) :
+     ctx_base(kv->get_base()->init_full()),
+     ctx_swa (kv->get_swa ()->init_full()),
+     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+ }
+
+ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+         llama_kv_cache_iswa * kv,
+         llama_context * lctx,
+         bool optimize) :
+     ctx_base(kv->get_base()->init_update(lctx, optimize)),
+     ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+ }
+
+ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+         llama_kv_cache_iswa * kv,
+         slot_info_vec_t sinfos_base,
+         slot_info_vec_t sinfos_swa,
+         std::vector<llama_ubatch> ubatches) :
+     ubatches(std::move(ubatches)),
+     // note: here we copy the ubatches. not sure if this is ideal
+     ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+     ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
+     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+ }
+
+ llama_kv_cache_iswa_context::~llama_kv_cache_iswa_context() = default;
+
+ bool llama_kv_cache_iswa_context::next() {
+     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+     ctx_base->next();
+     ctx_swa ->next();
+
+     if (++i_next >= ubatches.size()) {
+         return false;
+     }
+
+     return true;
+ }
+
+ bool llama_kv_cache_iswa_context::apply() {
+     assert(!llama_memory_status_is_fail(status));
+
+     bool res = true;
+
+     res = res & ctx_base->apply();
+     res = res & ctx_swa ->apply();
+
+     return res;
+ }
+
+ llama_memory_status llama_kv_cache_iswa_context::get_status() const {
+     return status;
+ }
+
+ const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
+     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+     return ubatches[i_next];
+ }
+
+ const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
+     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+     return static_cast<const llama_kv_cache_context *>(ctx_base.get());
+ }
+
+ const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
+     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+     return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
+ }
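
The SWA cache sizing in the constructor above is where the memory saving comes from: SWA layers only need roughly n_swa positions per sequence plus one ubatch of headroom, rounded up to the padding granularity, rather than the full kv_size. Below is a minimal standalone sketch of that arithmetic; the numeric values are hypothetical stand-ins for hparams.n_swa, n_seq_max, n_ubatch, and n_pad, and pad_to re-implements what LM_GGML_PAD does, purely for illustration.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Illustrative stand-in for LM_GGML_PAD: round x up to a multiple of n.
    static uint32_t pad_to(uint32_t x, uint32_t n) { return ((x + n - 1) / n) * n; }

    int main() {
        // Hypothetical values, chosen only for illustration.
        const uint32_t kv_size   = 32768; // base (non-SWA) cache size in cells
        const uint32_t n_swa     = 4096;  // sliding-window width
        const uint32_t n_seq_max = 2;     // max parallel sequences
        const uint32_t n_ubatch  = 512;   // ubatch headroom for the window
        const uint32_t n_pad     = 256;   // cell padding granularity
        const bool     unified   = true;  // sequences share one cache

        // Mirrors the constructor's:
        //   size_swa = std::min(size_base, LM_GGML_PAD(n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad))
        const uint32_t size_swa = std::min(kv_size,
                pad_to(n_swa * (unified ? n_seq_max : 1) + n_ubatch, n_pad));

        std::printf("base cells: %u, SWA cells: %u\n", kv_size, size_swa);
        // -> base cells: 32768, SWA cells: 8704
    }

With a 32768-cell base cache this reserves only 8704 cells for the SWA layers, which is also why passing swa_full (forcing size_swa = size_base) is logged with LLAMA_LOG_WARN rather than silently accepted.

package/cpp/llama-kv-cache-iswa.h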
@@ -0,0 +1,135 @@
+ #pragma once
+
+ #include "llama-kv-cache.h"
+
+ #include <vector>
+
+ //
+ // llama_kv_cache_iswa
+ //
+
+ // utilizes two instances of llama_kv_cache
+ // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
+
+ class llama_kv_cache_iswa : public llama_memory_i {
+ public:
+     llama_kv_cache_iswa(
+             const llama_model & model,
+             lm_ggml_type type_k,
+             lm_ggml_type type_v,
+             bool v_trans,
+             bool offload,
+             bool swa_full,
+             bool unified,
+             uint32_t kv_size,
+             uint32_t n_seq_max,
+             uint32_t n_ubatch,
+             uint32_t n_pad,
+             const layer_filter_cb & filter,
+             const layer_reuse_cb & reuse);
+
+     ~llama_kv_cache_iswa() = default;
+
+     //
+     // llama_memory_i
+     //
+
+     llama_memory_context_ptr init_batch(
+             llama_batch_allocr & balloc,
+             uint32_t n_ubatch,
+             bool embd_all) override;
+
+     llama_memory_context_ptr init_full() override;
+
+     llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
+
+     bool get_can_shift() const override;
+
+     void clear(bool data) override;
+
+     bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+     void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+     void seq_keep(llama_seq_id seq_id) override;
+     void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+     void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+     // state write/load
+
+     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+     void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
+
+     //
+     // llama_kv_cache_iswa specific API
+     //
+
+     llama_kv_cache * get_base() const;
+     llama_kv_cache * get_swa () const;
+
+ private:
+     const llama_hparams & hparams;
+
+     const bool unified;
+
+     std::unique_ptr<llama_kv_cache> kv_base;
+     std::unique_ptr<llama_kv_cache> kv_swa;
+ };
+
+ class llama_kv_cache_iswa_context : public llama_memory_context_i {
+ public:
+     using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
+
+     // used for errors
+     llama_kv_cache_iswa_context(llama_memory_status status);
+
+     // used to create a full-cache context
+     llama_kv_cache_iswa_context(
+             llama_kv_cache_iswa * kv);
+
+     // used to create an update context
+     llama_kv_cache_iswa_context(
+             llama_kv_cache_iswa * kv,
+             llama_context * lctx,
+             bool optimize);
+
+     // used to create a batch processing context from a batch
+     llama_kv_cache_iswa_context(
+             llama_kv_cache_iswa * kv,
+             slot_info_vec_t sinfos_base,
+             slot_info_vec_t sinfos_swa,
+             std::vector<llama_ubatch> ubatches);
+
+     virtual ~llama_kv_cache_iswa_context();
+
+     //
+     // llama_memory_context_i
+     //
+
+     bool next() override;
+     bool apply() override;
+
+     llama_memory_status get_status() const override;
+     const llama_ubatch & get_ubatch() const override;
+
+     //
+     // llama_kv_cache_iswa_context specific API
+     //
+
+     const llama_kv_cache_context * get_base() const;
+     const llama_kv_cache_context * get_swa() const;
+
+ private:
+     //llama_kv_cache_iswa * kv;
+
+     // the index of the next ubatch to process
+     size_t i_next = 0;
+
+     std::vector<llama_ubatch> ubatches;
+
+     const llama_memory_context_ptr ctx_base;
+     const llama_memory_context_ptr ctx_swa;
+
+     const llama_memory_status status;
+ };
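
Both classes implement the generic memory interfaces (llama_memory_i / llama_memory_context_i), so callers never touch the two inner caches directly. The following is a hedged sketch, not code from this package, of how a decode loop might drive a context returned by init_batch; process_ubatch is a hypothetical stand-in for the caller's graph build/compute step.

    #include "llama-kv-cache-iswa.h"

    // Hypothetical: whatever the caller does with one ubatch (build + run the graph).
    void process_ubatch(const llama_ubatch & ub);

    // Sketch: consume every ubatch prepared by init_batch.
    static bool run_batch(llama_memory_context_i & mctx) {
        if (mctx.get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            return false; // prepare failed; the caller should back off or retry
        }
        do {
            if (!mctx.apply()) { // place the current ubatch in both caches
                return false;
            }
            process_ubatch(mctx.get_ubatch());
        } while (mctx.next()); // advance; false once all ubatches are done
        return true;
    }

The point of the wrapper is visible here: apply() and next() fan out to the base and SWA contexts in lockstep, so the two caches always agree on which ubatch is current.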