@novastera-oss/llamarn 0.2.5 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123) hide show
  1. package/RNLlamaCpp.podspec +3 -2
  2. package/android/CMakeLists.txt +6 -3
  3. package/android/src/main/cpp/include/llama.h +12 -8
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  12. package/cpp/LlamaCppModel.cpp +46 -65
  13. package/cpp/LlamaCppModel.h +5 -0
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/README.md +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +5 -8
  17. package/cpp/llama.cpp/common/arg.cpp +8 -6
  18. package/cpp/llama.cpp/common/chat-parser.cpp +4 -3
  19. package/cpp/llama.cpp/common/chat-parser.h +2 -1
  20. package/cpp/llama.cpp/common/chat.cpp +4 -4
  21. package/cpp/llama.cpp/common/common.cpp +2 -0
  22. package/cpp/llama.cpp/common/json-partial.cpp +5 -4
  23. package/cpp/llama.cpp/common/json-partial.h +2 -1
  24. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  25. package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +31 -28
  27. package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
  28. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +2 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
  30. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +23 -0
  32. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +19 -8
  35. package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
  39. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml.c +9 -2
  41. package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
  42. package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
  43. package/cpp/llama.cpp/include/llama.h +12 -8
  44. package/cpp/llama.cpp/src/CMakeLists.txt +3 -0
  45. package/cpp/llama.cpp/src/llama-batch.cpp +19 -12
  46. package/cpp/llama.cpp/src/llama-batch.h +15 -10
  47. package/cpp/llama.cpp/src/llama-context.cpp +226 -151
  48. package/cpp/llama.cpp/src/llama-context.h +25 -8
  49. package/cpp/llama.cpp/src/llama-graph.cpp +50 -47
  50. package/cpp/llama.cpp/src/llama-graph.h +25 -24
  51. package/cpp/llama.cpp/src/llama-kv-cache-recurrent.cpp +1132 -0
  52. package/cpp/llama.cpp/src/llama-kv-cache-recurrent.h +191 -0
  53. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +249 -0
  54. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +136 -0
  55. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1717 -0
  56. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +278 -0
  57. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2746
  58. package/cpp/llama.cpp/src/llama-kv-cache.h +14 -472
  59. package/cpp/llama.cpp/src/llama-kv-cells.h +37 -6
  60. package/cpp/llama.cpp/src/llama-memory.h +44 -0
  61. package/cpp/llama.cpp/src/llama-model.cpp +23 -16
  62. package/cpp/llama.cpp/src/llama-vocab.cpp +7 -2
  63. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
  64. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
  65. package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
  66. package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
  67. package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
  68. package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  69. package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
  70. package/cpp/rn-completion.cpp +63 -8
  71. package/cpp/rn-utils.hpp +8 -1
  72. package/ios/include/common/minja/chat-template.hpp +1 -1
  73. package/ios/include/common/minja/minja.hpp +1 -1
  74. package/ios/include/json-schema-to-grammar.h +4 -4
  75. package/ios/include/llama.h +12 -8
  76. package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
  77. package/ios/libs/llama.xcframework/Info.plist +22 -22
  78. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  79. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4617
  80. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
  81. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +12 -8
  82. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  83. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  84. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
  85. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3557
  86. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  87. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
  88. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  89. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  90. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4638
  91. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3624 -3559
  92. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
  93. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +12 -8
  94. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
  95. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +12 -8
  96. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  97. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
  98. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +12 -8
  99. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  100. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  101. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  102. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4689 -4616
  103. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
  104. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +12 -8
  105. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  106. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  107. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4710 -4637
  108. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3622 -3556
  109. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  110. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
  111. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  112. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  113. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4725 -4653
  114. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
  115. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +12 -8
  116. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  117. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  118. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4746 -4674
  119. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3652 -3587
  120. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  121. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +12 -8
  122. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  123. package/package.json +1 -1
@@ -18,6 +18,9 @@ struct llama_kv_cache;
18
18
  class llama_io_read_i;
19
19
  class llama_io_write_i;
20
20
 
21
+ class llama_memory_i;
22
+ class llama_memory_state_i;
23
+
21
24
  struct llama_context {
22
25
  // init scheduler and compute buffers, reserve worst-case graphs
23
26
  llama_context(
@@ -47,7 +50,9 @@ struct llama_context {
47
50
  llama_kv_cache * get_kv_self();
48
51
  const llama_kv_cache * get_kv_self() const;
49
52
 
50
- void kv_self_update();
53
+ // return true if the KV cache was updated
54
+ // TODO: remove
55
+ bool kv_self_update();
51
56
 
52
57
  enum llama_pooling_type pooling_type() const;
53
58
 
@@ -88,6 +93,16 @@ struct llama_context {
88
93
  int32_t il_start,
89
94
  int32_t il_end);
90
95
 
96
+ // process a single ubatch with a specific graph type
97
+ // if memory_state is provided, it will be applied first to the context's memory
98
+ // ret contains the status of the graph computation
99
+ // returns nullptr only if ret != GGML_STATUS_SUCCESS
100
+ llm_graph_result_ptr process_ubatch(
101
+ const llama_ubatch & ubatch,
102
+ llm_graph_type gtype,
103
+ llama_memory_state_i * mstate,
104
+ ggml_status & ret);
105
+
91
106
  int encode(llama_batch & inp_batch);
92
107
  int decode(llama_batch & inp_batch);
93
108
 
@@ -180,16 +195,18 @@ public:
180
195
  ggml_cgraph * graph_init();
181
196
 
182
197
  // returns the result of ggml_backend_sched_graph_compute_async execution
183
- ggml_status graph_compute(
184
- ggml_cgraph * gf,
185
- bool batched);
198
+ ggml_status graph_compute(ggml_cgraph * gf, bool batched);
199
+
200
+ // reserve a graph with a dummy ubatch of the specified size
201
+ ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
186
202
 
187
203
  private:
188
204
  llm_graph_result_ptr graph_build(
189
- ggml_context * ctx,
190
- ggml_cgraph * gf,
191
- const llama_ubatch & ubatch,
192
- llm_graph_type gtype);
205
+ ggml_context * ctx,
206
+ ggml_cgraph * gf,
207
+ const llama_ubatch & ubatch,
208
+ llm_graph_type gtype,
209
+ const llama_memory_state_i * mstate);
193
210
 
194
211
  llm_graph_cb graph_get_cb() const;
195
212
 
@@ -3,7 +3,10 @@
3
3
  #include "llama-impl.h"
4
4
  #include "llama-batch.h"
5
5
  #include "llama-cparams.h"
6
- #include "llama-kv-cache.h"
6
+
7
+ #include "llama-kv-cache-unified.h"
8
+ #include "llama-kv-cache-unified-iswa.h"
9
+ #include "llama-kv-cache-recurrent.h"
7
10
 
8
11
  #include <cassert>
9
12
  #include <cmath>
@@ -83,7 +86,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
83
86
 
84
87
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
85
88
  if (pos_bucket) {
86
- kv_self->set_input_pos_bucket(pos_bucket, ubatch);
89
+ kv_state->set_input_pos_bucket(pos_bucket, ubatch);
87
90
  }
88
91
  }
89
92
 
@@ -234,7 +237,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
234
237
  void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
235
238
  GGML_UNUSED(ubatch);
236
239
 
237
- const int64_t n_kv = kv_self->n;
240
+ const int64_t n_kv = kv_state->get_n_kv();
238
241
 
239
242
  if (s_copy) {
240
243
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -242,7 +245,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
242
245
 
243
246
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
244
247
  for (uint32_t i = 0; i < n_kv; ++i) {
245
- data[i] = kv_self->s_copy(i);
248
+ data[i] = kv_state->s_copy(i);
246
249
  }
247
250
  }
248
251
  }
@@ -250,7 +253,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
250
253
  void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
251
254
  GGML_UNUSED(ubatch);
252
255
 
253
- const int64_t n_kv = kv_self->n;
256
+ const int64_t n_kv = kv_state->get_n_kv();
254
257
 
255
258
  if (s_mask) {
256
259
  GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
@@ -258,7 +261,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
258
261
 
259
262
  // clear unused states
260
263
  for (int i = 0; i < n_kv; ++i) {
261
- data[i] = kv_self->s_mask(i);
264
+ data[i] = kv_state->s_mask(i);
262
265
  }
263
266
  }
264
267
  }
@@ -362,17 +365,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
362
365
 
363
366
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
364
367
  if (self_kq_mask) {
365
- kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
368
+ kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
366
369
  }
367
370
  }
368
371
 
369
372
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
370
373
  if (self_kq_mask) {
371
- kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
374
+ kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
372
375
  }
373
376
 
374
377
  if (self_kq_mask_swa) {
375
- kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
378
+ kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
376
379
  }
377
380
  }
378
381
 
@@ -448,7 +451,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
448
451
  backend_cpu (params.backend_cpu),
449
452
  cvec (params.cvec),
450
453
  loras (params.loras),
451
- memory (params.memory),
454
+ mstate (params.mstate),
452
455
  cross (params.cross),
453
456
  cb_func (params.cb),
454
457
  res (std::make_unique<llm_graph_result>()) {
@@ -954,11 +957,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
954
957
  }
955
958
 
956
959
  ggml_tensor * llm_graph_context::build_inp_s_copy() const {
957
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
960
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
958
961
 
959
- auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
962
+ auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
960
963
 
961
- const auto n_kv = kv_self->n;
964
+ const auto n_kv = kv_state->get_n_kv();
962
965
 
963
966
  auto & cur = inp->s_copy;
964
967
 
@@ -971,11 +974,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
971
974
  }
972
975
 
973
976
  ggml_tensor * llm_graph_context::build_inp_s_mask() const {
974
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
977
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
975
978
 
976
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
979
+ auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
977
980
 
978
- const auto n_kv = kv_self->n;
981
+ const auto n_kv = kv_state->get_n_kv();
979
982
 
980
983
  auto & cur = inp->s_mask;
981
984
 
@@ -1025,11 +1028,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1025
1028
  }
1026
1029
 
1027
1030
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1028
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1031
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1029
1032
 
1030
- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
1033
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
1031
1034
 
1032
- const auto n_kv = kv_self->get_n();
1035
+ const auto n_kv = kv_state->get_n_kv();
1033
1036
 
1034
1037
  auto & cur = inp->pos_bucket;
1035
1038
 
@@ -1231,14 +1234,14 @@ ggml_tensor * llm_graph_context::build_attn(
1231
1234
  }
1232
1235
 
1233
1236
  llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1234
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1237
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1235
1238
 
1236
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
1239
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
1237
1240
 
1238
1241
  {
1239
1242
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1240
1243
 
1241
- const auto n_kv = kv_self->get_n();
1244
+ const auto n_kv = kv_state->get_n_kv();
1242
1245
 
1243
1246
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1244
1247
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,19 +1271,19 @@ ggml_tensor * llm_graph_context::build_attn(
1268
1271
  ggml_build_forward_expand(gf, k_cur);
1269
1272
  ggml_build_forward_expand(gf, v_cur);
1270
1273
 
1271
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1274
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1272
1275
 
1273
1276
  // store to KV cache
1274
1277
  {
1275
- ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
1276
- ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
1278
+ ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1279
+ ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1277
1280
  }
1278
1281
 
1279
1282
  const auto & kq_mask = inp->get_kq_mask();
1280
1283
 
1281
1284
  ggml_tensor * q = q_cur;
1282
- ggml_tensor * k = kv_self->get_k(ctx0, il);
1283
- ggml_tensor * v = kv_self->get_v(ctx0, il);
1285
+ ggml_tensor * k = kv_state->get_k(ctx0, il);
1286
+ ggml_tensor * v = kv_state->get_v(ctx0, il);
1284
1287
 
1285
1288
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1286
1289
  cb(cur, "kqv_out", il);
@@ -1301,12 +1304,12 @@ ggml_tensor * llm_graph_context::build_attn(
1301
1304
  }
1302
1305
 
1303
1306
  llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1304
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1307
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1305
1308
 
1306
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
1309
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1307
1310
 
1308
1311
  {
1309
- const auto n_kv = kv_self->get_kv_base()->get_n();
1312
+ const auto n_kv = kv_state->get_base()->get_n_kv();
1310
1313
 
1311
1314
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1312
1315
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1318,7 +1321,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1318
1321
  {
1319
1322
  GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1320
1323
 
1321
- const auto n_kv = kv_self->get_kv_swa()->get_n();
1324
+ const auto n_kv = kv_state->get_swa()->get_n_kv();
1322
1325
 
1323
1326
  inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1324
1327
  //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1348,23 +1351,23 @@ ggml_tensor * llm_graph_context::build_attn(
1348
1351
  ggml_build_forward_expand(gf, k_cur);
1349
1352
  ggml_build_forward_expand(gf, v_cur);
1350
1353
 
1351
- const bool is_swa = hparams.is_swa(il);
1354
+ const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1352
1355
 
1353
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1356
+ const bool is_swa = hparams.is_swa(il);
1354
1357
 
1355
- const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
1358
+ const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
1356
1359
 
1357
1360
  // store to KV cache
1358
1361
  {
1359
- ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
1360
- ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
1362
+ ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1363
+ ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1361
1364
  }
1362
1365
 
1363
1366
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1364
1367
 
1365
1368
  ggml_tensor * q = q_cur;
1366
- ggml_tensor * k = kv->get_k(ctx0, il);
1367
- ggml_tensor * v = kv->get_v(ctx0, il);
1369
+ ggml_tensor * k = kv_state->get_k(ctx0, il);
1370
+ ggml_tensor * v = kv_state->get_v(ctx0, il);
1368
1371
 
1369
1372
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1370
1373
  cb(cur, "kqv_out", il);
@@ -1446,12 +1449,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
1446
1449
  ggml_tensor * state_mask,
1447
1450
  int32_t n_state,
1448
1451
  int32_t n_seqs) const {
1449
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1452
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1450
1453
 
1451
- const auto n_kv = kv_self->n;
1452
- const auto kv_head = kv_self->head;
1454
+ const auto n_kv = kv_state->get_n_kv();
1455
+ const auto kv_head = kv_state->get_head();
1453
1456
 
1454
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
1457
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
1455
1458
 
1456
1459
  // copy states
1457
1460
  // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
@@ -1478,13 +1481,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1478
1481
  ggml_tensor * state_mask,
1479
1482
  const llama_ubatch & ubatch,
1480
1483
  int il) const {
1481
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1484
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1482
1485
 
1483
1486
  const auto token_shift_count = hparams.token_shift_count;
1484
1487
 
1485
1488
  const int64_t n_seqs = ubatch.n_seqs;
1486
1489
 
1487
- ggml_tensor * token_shift_all = kv_self->k_l[il];
1490
+ ggml_tensor * token_shift_all = kv_state->get_k_l(il);
1488
1491
 
1489
1492
  ggml_tensor * token_shift = build_copy_mask_state(
1490
1493
  gf, token_shift_all, state_copy, state_mask,
@@ -1499,19 +1502,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1499
1502
  ggml_tensor * token_shift,
1500
1503
  const llama_ubatch & ubatch,
1501
1504
  int il) const {
1502
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1505
+ const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1503
1506
 
1504
1507
  const auto token_shift_count = hparams.token_shift_count;
1505
1508
  const auto n_embd = hparams.n_embd;
1506
1509
 
1507
1510
  const int64_t n_seqs = ubatch.n_seqs;
1508
1511
 
1509
- const auto kv_head = kv_self->head;
1512
+ const auto kv_head = kv_state->get_head();
1510
1513
 
1511
1514
  return ggml_cpy(
1512
1515
  ctx0,
1513
1516
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1514
- ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
1517
+ ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
1515
1518
  );
1516
1519
  }
1517
1520
 
@@ -17,10 +17,11 @@ struct ggml_tensor;
17
17
  struct llama_ubatch;
18
18
  struct llama_cparams;
19
19
 
20
- class llama_memory_i;
21
- class llama_kv_cache_unified;
22
- class llama_kv_cache_unified_iswa;
23
- class llama_kv_cache_recurrent;
20
+ class llama_memory_state_i;
21
+
22
+ class llama_kv_cache_unified_state;
23
+ class llama_kv_cache_unified_iswa_state;
24
+ class llama_kv_cache_recurrent_state;
24
25
 
25
26
  // certain models (typically multi-modal) can produce different types of graphs
26
27
  enum llm_graph_type {
@@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
133
134
  public:
134
135
  llm_graph_input_pos_bucket_kv(
135
136
  const llama_hparams & hparams,
136
- const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
137
+ const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
137
138
  virtual ~llm_graph_input_pos_bucket_kv() = default;
138
139
 
139
140
  void set_input(const llama_ubatch * ubatch) override;
@@ -141,7 +142,7 @@ public:
141
142
  ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
142
143
 
143
144
  const llama_hparams & hparams;
144
- const llama_kv_cache_unified * kv_self;
145
+ const llama_kv_cache_unified_state * kv_state;
145
146
  };
146
147
 
147
148
  class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -188,26 +189,26 @@ public:
188
189
 
189
190
  class llm_graph_input_s_copy : public llm_graph_input_i {
190
191
  public:
191
- llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
192
+ llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
192
193
  virtual ~llm_graph_input_s_copy() = default;
193
194
 
194
195
  void set_input(const llama_ubatch * ubatch) override;
195
196
 
196
197
  ggml_tensor * s_copy; // I32 [kv_size]
197
198
 
198
- const llama_kv_cache_recurrent * kv_self;
199
+ const llama_kv_cache_recurrent_state * kv_state;
199
200
  };
200
201
 
201
202
  class llm_graph_input_s_mask : public llm_graph_input_i {
202
203
  public:
203
- llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
204
+ llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
204
205
  virtual ~llm_graph_input_s_mask() = default;
205
206
 
206
207
  void set_input(const llama_ubatch * ubatch) override;
207
208
 
208
209
  ggml_tensor * s_mask; // F32 [1, n_kv]
209
210
 
210
- const llama_kv_cache_recurrent * kv_self;
211
+ const llama_kv_cache_recurrent_state * kv_state;
211
212
  };
212
213
 
213
214
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -247,10 +248,10 @@ public:
247
248
  llm_graph_input_attn_kv_unified(
248
249
  const llama_hparams & hparams,
249
250
  const llama_cparams & cparams,
250
- const llama_kv_cache_unified * kv_self) :
251
+ const llama_kv_cache_unified_state * kv_state) :
251
252
  hparams(hparams),
252
253
  cparams(cparams),
253
- kv_self(kv_self) {
254
+ kv_state(kv_state) {
254
255
  }
255
256
  ~llm_graph_input_attn_kv_unified() = default;
256
257
 
@@ -264,7 +265,7 @@ public:
264
265
  const llama_hparams & hparams;
265
266
  const llama_cparams & cparams;
266
267
 
267
- const llama_kv_cache_unified * kv_self;
268
+ const llama_kv_cache_unified_state * kv_state;
268
269
  };
269
270
 
270
271
  class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -272,10 +273,10 @@ public:
272
273
  llm_graph_input_attn_kv_unified_iswa(
273
274
  const llama_hparams & hparams,
274
275
  const llama_cparams & cparams,
275
- const llama_kv_cache_unified_iswa * kv_self) :
276
+ const llama_kv_cache_unified_iswa_state * kv_state) :
276
277
  hparams(hparams),
277
278
  cparams(cparams),
278
- kv_self(kv_self) {
279
+ kv_state(kv_state) {
279
280
  }
280
281
  ~llm_graph_input_attn_kv_unified_iswa() = default;
281
282
 
@@ -292,7 +293,7 @@ public:
292
293
  const llama_hparams & hparams;
293
294
  const llama_cparams & cparams;
294
295
 
295
- const llama_kv_cache_unified_iswa * kv_self;
296
+ const llama_kv_cache_unified_iswa_state * kv_state;
296
297
  };
297
298
 
298
299
  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -383,10 +384,10 @@ struct llm_graph_params {
383
384
  ggml_backend_sched_t sched;
384
385
  ggml_backend_t backend_cpu;
385
386
 
386
- const llama_adapter_cvec * cvec;
387
- const llama_adapter_loras * loras;
388
- const llama_memory_i * memory;
389
- const llama_cross * cross;
387
+ const llama_adapter_cvec * cvec;
388
+ const llama_adapter_loras * loras;
389
+ const llama_memory_state_i * mstate;
390
+ const llama_cross * cross;
390
391
 
391
392
  int32_t n_outputs;
392
393
 
@@ -435,10 +436,10 @@ struct llm_graph_context {
435
436
 
436
437
  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
437
438
 
438
- const llama_adapter_cvec * cvec;
439
- const llama_adapter_loras * loras;
440
- const llama_memory_i * memory;
441
- const llama_cross * cross;
439
+ const llama_adapter_cvec * cvec;
440
+ const llama_adapter_loras * loras;
441
+ const llama_memory_state_i * mstate;
442
+ const llama_cross * cross;
442
443
 
443
444
  const llm_graph_cb & cb_func;
444
445