@novastera-oss/llamarn 0.2.4 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. package/RNLlamaCpp.podspec +3 -2
  2. package/android/CMakeLists.txt +6 -3
  3. package/android/src/main/cpp/include/llama.h +12 -8
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  12. package/cpp/LlamaCppModel.cpp +46 -65
  13. package/cpp/LlamaCppModel.h +5 -0
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/README.md +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +5 -8
  17. package/cpp/llama.cpp/common/arg.cpp +8 -6
  18. package/cpp/llama.cpp/common/chat-parser.cpp +4 -3
  19. package/cpp/llama.cpp/common/chat-parser.h +2 -1
  20. package/cpp/llama.cpp/common/chat.cpp +4 -4
  21. package/cpp/llama.cpp/common/common.cpp +2 -0
  22. package/cpp/llama.cpp/common/json-partial.cpp +5 -4
  23. package/cpp/llama.cpp/common/json-partial.h +2 -1
  24. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  25. package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +31 -28
  27. package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
  28. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +2 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
  30. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +23 -0
  32. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +19 -8
  35. package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
  39. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml.c +9 -2
  41. package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
  42. package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
  43. package/cpp/llama.cpp/include/llama.h +12 -8
  44. package/cpp/llama.cpp/src/CMakeLists.txt +3 -0
  45. package/cpp/llama.cpp/src/llama-batch.cpp +19 -12
  46. package/cpp/llama.cpp/src/llama-batch.h +15 -10
  47. package/cpp/llama.cpp/src/llama-context.cpp +226 -151
  48. package/cpp/llama.cpp/src/llama-context.h +25 -8
  49. package/cpp/llama.cpp/src/llama-graph.cpp +50 -47
  50. package/cpp/llama.cpp/src/llama-graph.h +25 -24
  51. package/cpp/llama.cpp/src/llama-kv-cache-recurrent.cpp +1132 -0
  52. package/cpp/llama.cpp/src/llama-kv-cache-recurrent.h +191 -0
  53. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +249 -0
  54. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +136 -0
  55. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1717 -0
  56. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +278 -0
  57. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2746
  58. package/cpp/llama.cpp/src/llama-kv-cache.h +14 -472
  59. package/cpp/llama.cpp/src/llama-kv-cells.h +37 -6
  60. package/cpp/llama.cpp/src/llama-memory.h +44 -0
  61. package/cpp/llama.cpp/src/llama-model.cpp +23 -16
  62. package/cpp/llama.cpp/src/llama-vocab.cpp +7 -2
  63. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
  64. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
  65. package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
  66. package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
  67. package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
  68. package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  69. package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
  70. package/cpp/rn-completion.cpp +101 -52
  71. package/cpp/rn-utils.hpp +8 -1
  72. package/ios/include/common/minja/chat-template.hpp +1 -1
  73. package/ios/include/common/minja/minja.hpp +1 -1
  74. package/ios/include/json-schema-to-grammar.h +4 -4
  75. package/ios/include/llama.h +12 -8
  76. package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
  77. package/ios/libs/llama.xcframework/Info.plist +22 -22
  78. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  79. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4617
  80. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
  81. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +12 -8
  82. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  83. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  84. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
  85. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3557
  86. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  87. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
  88. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  89. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  90. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
  91. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3624 -3559
  92. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
  93. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +12 -8
  94. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
  95. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +12 -8
  96. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  97. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
  98. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +12 -8
  99. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  100. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  101. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  102. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4616
  103. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
  104. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +12 -8
  105. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  106. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  107. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4637
  108. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3556
  109. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  110. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
  111. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  112. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  113. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4725 -4653
  114. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
  115. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +12 -8
  116. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  117. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  118. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4746 -4674
  119. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3652 -3587
  120. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  121. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
  122. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  123. package/package.json +1 -1
@@ -0,0 +1,191 @@
1
+ #pragma once
2
+
3
+ #include "llama-batch.h"
4
+ #include "llama-graph.h"
5
+ #include "llama-kv-cache.h"
6
+
7
+ #include <set>
8
+ #include <vector>
9
+
10
+ //
11
+ // llama_kv_cache_recurrent
12
+ //
13
+
14
+ // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
15
+ // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
16
+ class llama_kv_cache_recurrent : public llama_kv_cache {
17
+ public:
18
+ llama_kv_cache_recurrent(
19
+ const llama_model & model,
20
+ ggml_type type_k,
21
+ ggml_type type_v,
22
+ bool offload,
23
+ uint32_t kv_size,
24
+ uint32_t n_seq_max);
25
+
26
+ ~llama_kv_cache_recurrent() = default;
27
+
28
+ //
29
+ // llama_memory_i
30
+ //
31
+
32
+ void clear() override;
33
+
34
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
35
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
36
+ void seq_keep(llama_seq_id seq_id) override;
37
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
38
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
39
+
40
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
41
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
42
+
43
+ //
44
+ // llama_kv_cache
45
+ //
46
+
47
+ llama_memory_state_ptr init_batch(
48
+ const llama_batch & batch,
49
+ uint32_t n_ubatch,
50
+ bool embd_pooled,
51
+ bool logits_all) override;
52
+
53
+ llama_memory_state_ptr init_full() override;
54
+
55
+ bool update(llama_context & lctx) override;
56
+
57
+ void defrag_sched(float thold) override;
58
+
59
+ bool prepare(const std::vector<llama_ubatch> & ubatches);
60
+
61
+ // find a contiguous slot of kv cells and emplace the ubatch there
62
+ bool find_slot(const llama_ubatch & ubatch);
63
+
64
+ bool get_can_shift() const override;
65
+
66
+ // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
67
+ int32_t s_copy(int i) const;
68
+ float s_mask(int i) const;
69
+
70
+ // state write/load
71
+
72
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
73
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
74
+
75
+ uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
76
+ uint32_t size = 0; // total number of cells, shared across all sequences
77
+ uint32_t used = 0; // used cells (i.e. at least one seq_id)
78
+
79
+ // computed before each graph build
80
+ uint32_t n = 0;
81
+
82
+ // TODO: optimize for recurrent state needs
83
+ struct kv_cell {
84
+ llama_pos pos = -1;
85
+ int32_t src = -1; // used to copy states
86
+ int32_t tail = -1;
87
+
88
+ std::set<llama_seq_id> seq_id;
89
+
90
+ bool has_seq_id(const llama_seq_id & id) const {
91
+ return seq_id.find(id) != seq_id.end();
92
+ }
93
+
94
+ bool is_empty() const {
95
+ return seq_id.empty();
96
+ }
97
+
98
+ bool is_same_seq(const kv_cell & other) const {
99
+ return seq_id == other.seq_id;
100
+ }
101
+ };
102
+
103
+ std::vector<kv_cell> cells;
104
+
105
+ std::vector<ggml_tensor *> k_l; // per layer
106
+ std::vector<ggml_tensor *> v_l;
107
+
108
+ private:
109
+ //const llama_model & model;
110
+ const llama_hparams & hparams;
111
+
112
+ const uint32_t n_seq_max = 1;
113
+
114
+ std::vector<ggml_context_ptr> ctxs;
115
+ std::vector<ggml_backend_buffer_ptr> bufs;
116
+
117
+ size_t total_size() const;
118
+
119
+ size_t size_k_bytes() const;
120
+ size_t size_v_bytes() const;
121
+
122
+ void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
123
+ void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
124
+
125
+ bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
126
+ bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
127
+ };
128
+
129
+ class llama_kv_cache_recurrent_state : public llama_memory_state_i {
130
+ public:
131
+ // used for errors
132
+ llama_kv_cache_recurrent_state(llama_memory_status status);
133
+
134
+ // used to create a full-cache state
135
+ llama_kv_cache_recurrent_state(
136
+ llama_memory_status status,
137
+ llama_kv_cache_recurrent * kv);
138
+
139
+ // used to create a state from a batch
140
+ llama_kv_cache_recurrent_state(
141
+ llama_memory_status status,
142
+ llama_kv_cache_recurrent * kv,
143
+ llama_sbatch sbatch,
144
+ std::vector<llama_ubatch> ubatches);
145
+
146
+ virtual ~llama_kv_cache_recurrent_state();
147
+
148
+ //
149
+ // llama_memory_state_i
150
+ //
151
+
152
+ bool next() override;
153
+ bool apply() override;
154
+
155
+ std::vector<int64_t> & out_ids() override;
156
+
157
+ llama_memory_status get_status() const override;
158
+ const llama_ubatch & get_ubatch() const override;
159
+
160
+ //
161
+ // llama_kv_cache_recurrent_state specific API
162
+ //
163
+
164
+ uint32_t get_n_kv() const;
165
+ uint32_t get_head() const;
166
+ uint32_t get_size() const;
167
+
168
+ ggml_tensor * get_k_l(int32_t il) const;
169
+ ggml_tensor * get_v_l(int32_t il) const;
170
+
171
+ int32_t s_copy(int i) const;
172
+ float s_mask(int i) const;
173
+
174
+ private:
175
+ const llama_memory_status status;
176
+
177
+ llama_kv_cache_recurrent * kv;
178
+
179
+ llama_sbatch sbatch;
180
+
181
+ size_t i_next = 0;
182
+
183
+ std::vector<llama_ubatch> ubatches;
184
+
185
+ //
186
+ // data needed for building the compute graph for the current ubatch:
187
+ // TODO: extract all the state like `head` and `n` here
188
+ //
189
+
190
+ const bool is_full = false;
191
+ };
@@ -0,0 +1,249 @@
1
+ #include "llama-kv-cache-unified-iswa.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-batch.h"
5
+ #include "llama-model.h"
6
+
7
+ #include <algorithm>
8
+ #include <cassert>
9
+
10
+ //
11
+ // llama_kv_cache_unified_iswa
12
+ //
13
+
14
+ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15
+ const llama_model & model,
16
+ ggml_type type_k,
17
+ ggml_type type_v,
18
+ bool v_trans,
19
+ bool offload,
20
+ bool swa_full,
21
+ uint32_t kv_size,
22
+ uint32_t n_seq_max,
23
+ uint32_t n_ubatch,
24
+ uint32_t n_pad) : hparams(model.hparams) {
25
+ llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
26
+ llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
27
+
28
+ const uint32_t size_base = kv_size;
29
+
30
+ uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
31
+
32
+ // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
33
+ if (swa_full) {
34
+ LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
35
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
36
+
37
+ size_swa = size_base;
38
+ }
39
+
40
+ LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
41
+
42
+ kv_base = std::make_unique<llama_kv_cache_unified>(
43
+ model, std::move(filter_base), type_k, type_v,
44
+ v_trans, offload, size_base, n_seq_max, n_pad,
45
+ 0, LLAMA_SWA_TYPE_NONE);
46
+
47
+ LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
48
+
49
+ kv_swa = std::make_unique<llama_kv_cache_unified>(
50
+ model, std::move(filter_swa), type_k, type_v,
51
+ v_trans, offload, size_swa, n_seq_max, n_pad,
52
+ hparams.n_swa, hparams.swa_type);
53
+ }
54
+
55
+ void llama_kv_cache_unified_iswa::clear() {
56
+ kv_base->clear();
57
+ kv_swa ->clear();
58
+ }
59
+
60
+ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
61
+ bool res = true;
62
+
63
+ res = res & kv_base->seq_rm(seq_id, p0, p1);
64
+ res = res & kv_swa ->seq_rm(seq_id, p0, p1);
65
+
66
+ return res;
67
+ }
68
+
69
+ void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
70
+ kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
71
+ kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
72
+ }
73
+
74
+ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
75
+ kv_base->seq_keep(seq_id);
76
+ kv_swa ->seq_keep(seq_id);
77
+ }
78
+
79
+ void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
80
+ kv_base->seq_add(seq_id, p0, p1, shift);
81
+ kv_swa ->seq_add(seq_id, p0, p1, shift);
82
+ }
83
+
84
+ void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
85
+ kv_base->seq_div(seq_id, p0, p1, d);
86
+ kv_swa ->seq_div(seq_id, p0, p1, d);
87
+ }
88
+
89
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
90
+ // the base cache is a superset of the SWA cache, so we can just check the SWA cache
91
+ return kv_swa->seq_pos_min(seq_id);
92
+ }
93
+
94
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
95
+ return kv_swa->seq_pos_max(seq_id);
96
+ }
97
+
98
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
99
+ GGML_UNUSED(embd_pooled);
100
+
101
+ // TODO: if we fail with split_simple, we should attempt different splitting strategies
102
+ // but to do that properly, we first have to refactor the batches to be more flexible
103
+
104
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
105
+
106
+ std::vector<llama_ubatch> ubatches;
107
+
108
+ while (sbatch.n_tokens > 0) {
109
+ auto ubatch = sbatch.split_simple(n_ubatch);
110
+
111
+ ubatches.push_back(ubatch);
112
+ }
113
+
114
+ auto heads_base = kv_base->prepare(ubatches);
115
+ if (heads_base.empty()) {
116
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
117
+ }
118
+
119
+ auto heads_swa = kv_swa->prepare(ubatches);
120
+ if (heads_swa.empty()) {
121
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
122
+ }
123
+
124
+ assert(heads_base.size() == heads_swa.size());
125
+
126
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
127
+ this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
128
+ }
129
+
130
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
131
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
132
+ }
133
+
134
+ bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
135
+ bool res = false;
136
+
137
+ res = res | kv_base->update(lctx);
138
+ res = res | kv_swa ->update(lctx);
139
+
140
+ return res;
141
+ }
142
+
143
+ void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
144
+ kv_base->defrag_sched(thold);
145
+ kv_swa ->defrag_sched(thold);
146
+ }
147
+
148
+ bool llama_kv_cache_unified_iswa::get_can_shift() const {
149
+ return kv_base->get_size() == kv_swa->get_size();
150
+ }
151
+
152
+ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
153
+ kv_base->state_write(io, seq_id);
154
+ kv_swa ->state_write(io, seq_id);
155
+ }
156
+
157
+ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
158
+ kv_base->state_read(io, seq_id);
159
+ kv_swa ->state_read(io, seq_id);
160
+ }
161
+
162
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
163
+ return kv_base.get();
164
+ }
165
+
166
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
167
+ return kv_swa.get();
168
+ }
169
+
170
+ //
171
+ // llama_kv_cache_unified_iswa_state
172
+ //
173
+
174
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
175
+
176
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
177
+ llama_memory_status status,
178
+ llama_kv_cache_unified_iswa * kv) : status(status) {
179
+ state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
180
+ state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
181
+ }
182
+
183
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
184
+ llama_memory_status status,
185
+ llama_kv_cache_unified_iswa * kv,
186
+ llama_sbatch sbatch,
187
+ std::vector<uint32_t> heads_base,
188
+ std::vector<uint32_t> heads_swa,
189
+ std::vector<llama_ubatch> ubatches)
190
+ : status(status),
191
+ sbatch(std::move(sbatch)),
192
+ ubatches(std::move(ubatches)) {
193
+ // note: here we copy the ubatches. not sure if this is ideal
194
+ state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
195
+ state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
196
+ }
197
+
198
+ llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
199
+
200
+ bool llama_kv_cache_unified_iswa_state::next() {
201
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
202
+
203
+ state_base->next();
204
+ state_swa ->next();
205
+
206
+ if (++i_next >= ubatches.size()) {
207
+ return false;
208
+ }
209
+
210
+ return true;
211
+ }
212
+
213
+ bool llama_kv_cache_unified_iswa_state::apply() {
214
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
215
+
216
+ bool res = true;
217
+
218
+ res = res & state_base->apply();
219
+ res = res & state_swa ->apply();
220
+
221
+ return res;
222
+ }
223
+
224
+ std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
225
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
226
+
227
+ return sbatch.out_ids;
228
+ }
229
+
230
+ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
231
+ return status;
232
+ }
233
+
234
+ const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
235
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
236
+ return ubatches[i_next];
237
+ }
238
+
239
+ const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
240
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
241
+
242
+ return state_base.get();
243
+ }
244
+
245
+ const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
246
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
247
+
248
+ return state_swa.get();
249
+ }
@@ -0,0 +1,136 @@
1
+ #pragma once
2
+
3
+ #include "llama-kv-cache-unified.h"
4
+
5
+ #include <vector>
6
+
7
+ //
8
+ // llama_kv_cache_unified_iswa
9
+ //
10
+
11
+ // utilizes two instances of llama_kv_cache_unified
12
+ // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
13
+
14
+ class llama_kv_cache_unified_iswa : public llama_kv_cache {
15
+ public:
16
+ llama_kv_cache_unified_iswa(
17
+ const llama_model & model,
18
+ ggml_type type_k,
19
+ ggml_type type_v,
20
+ bool v_trans,
21
+ bool offload,
22
+ bool swa_full,
23
+ uint32_t kv_size,
24
+ uint32_t n_seq_max,
25
+ uint32_t n_ubatch,
26
+ uint32_t n_pad);
27
+
28
+ ~llama_kv_cache_unified_iswa() = default;
29
+
30
+ //
31
+ // llama_memory_i
32
+ //
33
+
34
+ void clear() override;
35
+
36
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
37
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
38
+ void seq_keep(llama_seq_id seq_id) override;
39
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
40
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
41
+
42
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
43
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
44
+
45
+ //
46
+ // llama_kv_cache
47
+ //
48
+
49
+ llama_memory_state_ptr init_batch(
50
+ const llama_batch & batch,
51
+ uint32_t n_ubatch,
52
+ bool embd_pooled,
53
+ bool logits_all) override;
54
+
55
+ llama_memory_state_ptr init_full() override;
56
+
57
+ bool update(llama_context & lctx) override;
58
+
59
+ void defrag_sched(float thold) override;
60
+
61
+ bool get_can_shift() const override;
62
+
63
+ // state write/load
64
+
65
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
66
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
67
+
68
+ //
69
+ // llama_kv_cache_unified_iswa specific API
70
+ //
71
+
72
+ llama_kv_cache_unified * get_base() const;
73
+ llama_kv_cache_unified * get_swa () const;
74
+
75
+ private:
76
+ const llama_hparams & hparams;
77
+
78
+ std::unique_ptr<llama_kv_cache_unified> kv_base;
79
+ std::unique_ptr<llama_kv_cache_unified> kv_swa;
80
+ };
81
+
82
+ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
83
+ public:
84
+ // used for errors
85
+ llama_kv_cache_unified_iswa_state(llama_memory_status status);
86
+
87
+ // used to create a full-cache state
88
+ llama_kv_cache_unified_iswa_state(
89
+ llama_memory_status status,
90
+ llama_kv_cache_unified_iswa * kv);
91
+
92
+ // used to create a state from a batch
93
+ llama_kv_cache_unified_iswa_state(
94
+ llama_memory_status status,
95
+ llama_kv_cache_unified_iswa * kv,
96
+ llama_sbatch sbatch,
97
+ std::vector<uint32_t> heads_base,
98
+ std::vector<uint32_t> heads_swa,
99
+ std::vector<llama_ubatch> ubatches);
100
+
101
+ virtual ~llama_kv_cache_unified_iswa_state();
102
+
103
+ //
104
+ // llama_memory_state_i
105
+ //
106
+
107
+ bool next() override;
108
+ bool apply() override;
109
+
110
+ std::vector<int64_t> & out_ids() override;
111
+
112
+ llama_memory_status get_status() const override;
113
+ const llama_ubatch & get_ubatch() const override;
114
+
115
+ //
116
+ // llama_kv_cache_unified_iswa_state specific API
117
+ //
118
+
119
+ const llama_kv_cache_unified_state * get_base() const;
120
+ const llama_kv_cache_unified_state * get_swa() const;
121
+
122
+ private:
123
+ const llama_memory_status status;
124
+
125
+ //llama_kv_cache_unified_iswa * kv;
126
+
127
+ llama_sbatch sbatch;
128
+
129
+ // the index of the next ubatch to process
130
+ size_t i_next = 0;
131
+
132
+ std::vector<llama_ubatch> ubatches;
133
+
134
+ std::unique_ptr<llama_kv_cache_unified_state> state_base;
135
+ std::unique_ptr<llama_kv_cache_unified_state> state_swa;
136
+ };