cui-llama.rn 1.6.1 → 1.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. package/android/src/main/CMakeLists.txt +6 -0
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +51 -14
  3. package/android/src/main/java/com/rnllama/RNLlama.java +158 -6
  4. package/android/src/main/jni.cpp +153 -14
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  14. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  15. package/cpp/chat.cpp +128 -106
  16. package/cpp/chat.h +2 -0
  17. package/cpp/common.cpp +38 -76
  18. package/cpp/common.h +23 -19
  19. package/cpp/ggml-backend.cpp +9 -5
  20. package/cpp/ggml-backend.h +4 -4
  21. package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  22. package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
  23. package/cpp/ggml-cpu/ggml-cpu.c +5 -13
  24. package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
  25. package/cpp/ggml-cpu/ops.cpp +107 -13
  26. package/cpp/ggml-cpu/vec.cpp +0 -6
  27. package/cpp/ggml-cpu/vec.h +16 -0
  28. package/cpp/ggml-llama-sim.metallib +0 -0
  29. package/cpp/ggml-llama.metallib +0 -0
  30. package/cpp/ggml-metal-impl.h +36 -11
  31. package/cpp/ggml-metal.m +321 -132
  32. package/cpp/ggml-opt.cpp +373 -190
  33. package/cpp/ggml-opt.h +49 -28
  34. package/cpp/ggml-quants.c +0 -6
  35. package/cpp/ggml.c +93 -38
  36. package/cpp/ggml.h +21 -7
  37. package/cpp/gguf.cpp +33 -33
  38. package/cpp/llama-adapter.cpp +6 -0
  39. package/cpp/llama-arch.cpp +3 -0
  40. package/cpp/llama-batch.cpp +3 -1
  41. package/cpp/llama-chat.cpp +8 -6
  42. package/cpp/llama-chat.h +1 -0
  43. package/cpp/llama-context.cpp +349 -135
  44. package/cpp/llama-context.h +30 -3
  45. package/cpp/llama-cparams.h +1 -0
  46. package/cpp/llama-graph.cpp +150 -234
  47. package/cpp/llama-graph.h +52 -7
  48. package/cpp/llama-hparams.cpp +17 -1
  49. package/cpp/llama-hparams.h +34 -5
  50. package/cpp/llama-kv-cache.cpp +662 -321
  51. package/cpp/llama-kv-cache.h +203 -93
  52. package/cpp/llama-memory.h +3 -2
  53. package/cpp/llama-model-loader.cpp +24 -15
  54. package/cpp/llama-model-saver.cpp +281 -0
  55. package/cpp/llama-model-saver.h +37 -0
  56. package/cpp/llama-model.cpp +536 -132
  57. package/cpp/llama-model.h +7 -1
  58. package/cpp/llama-sampling.cpp +18 -6
  59. package/cpp/llama-vocab.cpp +46 -8
  60. package/cpp/llama-vocab.h +6 -0
  61. package/cpp/llama.cpp +14 -0
  62. package/cpp/llama.h +72 -131
  63. package/cpp/minja/chat-template.hpp +9 -5
  64. package/cpp/minja/minja.hpp +69 -36
  65. package/cpp/rn-llama.cpp +611 -47
  66. package/cpp/rn-llama.h +33 -3
  67. package/cpp/sampling.cpp +57 -50
  68. package/cpp/tools/mtmd/clip-impl.h +462 -0
  69. package/cpp/tools/mtmd/clip.cpp +4024 -0
  70. package/cpp/tools/mtmd/clip.h +101 -0
  71. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  72. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  73. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  74. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  75. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  76. package/cpp/tools/mtmd/mtmd.h +362 -0
  77. package/cpp/tools/mtmd/stb_image.h +7988 -0
  78. package/ios/CMakeLists.txt +7 -0
  79. package/ios/RNLlama.mm +77 -3
  80. package/ios/RNLlamaContext.h +5 -1
  81. package/ios/RNLlamaContext.mm +105 -10
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  105. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  106. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  107. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  108. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  109. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  110. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  111. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  112. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  113. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  114. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  115. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  116. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  117. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  118. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  119. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  120. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  121. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  122. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  123. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  124. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  125. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  126. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  127. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  128. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  129. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  130. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
  131. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  132. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  133. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  134. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
  135. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  136. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  137. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  138. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  139. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  140. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  141. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  142. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  143. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  144. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  145. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
  146. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  147. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  148. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  149. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  150. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  151. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  152. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  153. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  154. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  155. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  156. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  157. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  158. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  159. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  160. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  161. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  162. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  163. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  164. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  165. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  166. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  167. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  168. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  169. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  170. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  171. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  172. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  173. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  174. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  176. package/jest/mock.js +33 -7
  177. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  178. package/lib/commonjs/index.js +153 -21
  179. package/lib/commonjs/index.js.map +1 -1
  180. package/lib/module/NativeRNLlama.js.map +1 -1
  181. package/lib/module/index.js +152 -20
  182. package/lib/module/index.js.map +1 -1
  183. package/lib/typescript/NativeRNLlama.d.ts +50 -4
  184. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  185. package/lib/typescript/index.d.ts +72 -6
  186. package/lib/typescript/index.d.ts.map +1 -1
  187. package/package.json +1 -1
  188. package/src/NativeRNLlama.ts +67 -4
  189. package/src/index.ts +212 -38
  190. package/lib/commonjs/chat.js +0 -37
  191. package/lib/commonjs/chat.js.map +0 -1
  192. package/lib/module/chat.js +0 -33
  193. package/lib/module/chat.js.map +0 -1
  194. package/lib/typescript/chat.d.ts +0 -10
  195. package/lib/typescript/chat.d.ts.map +0 -1
  196. package/src/chat.ts +0 -44
@@ -23,32 +23,21 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
  }
 
  llama_kv_cache_unified::llama_kv_cache_unified(
- const llama_model & model,
- lm_ggml_type type_k,
- lm_ggml_type type_v,
- bool v_trans,
- bool offload,
- uint32_t kv_size,
- uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
- const int32_t n_layer = hparams.n_layer;
-
- has_shift = false;
- can_shift = true;
-
- LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
- __func__, kv_size, lm_ggml_type_name(type_k), lm_ggml_type_name(type_v), n_layer, can_shift, padding);
-
- LM_GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
-
- head = 0;
- size = kv_size;
- used = 0;
-
- this->type_k = type_k;
- this->type_v = type_v;
-
- cells.clear();
- cells.resize(kv_size);
+ const llama_model & model,
+ layer_filter_cb && filter,
+ lm_ggml_type type_k,
+ lm_ggml_type type_v,
+ bool v_trans,
+ bool offload,
+ uint32_t kv_size,
+ uint32_t n_seq_max,
+ uint32_t n_pad,
+ uint32_t n_swa,
+ llama_swa_type swa_type) :
+ model(model), hparams(model.hparams), v_trans(v_trans),
+ n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+
+ LM_GGML_ASSERT(kv_size % n_pad == 0);
 
  // create a context for each buffer type
  std::map<lm_ggml_backend_buffer_type_t, lm_ggml_context *> ctx_map;
@@ -56,7 +45,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  auto it = ctx_map.find(buft);
  if (it == ctx_map.end()) {
  lm_ggml_init_params params = {
- /*.mem_size =*/ size_t(2u*n_layer*lm_ggml_tensor_overhead()),
+ /*.mem_size =*/ size_t(2u*hparams.n_layer*lm_ggml_tensor_overhead()),
  /*.mem_buffer =*/ NULL,
  /*.no_alloc =*/ true,
  };
@@ -75,37 +64,50 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  return it->second;
  };
 
- k_l.reserve(n_layer);
- v_l.reserve(n_layer);
+ head = 0;
+ size = kv_size;
+ used = 0;
 
- for (int i = 0; i < n_layer; i++) {
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+ cells.resize(kv_size);
+
+ for (uint32_t il = 0; il < hparams.n_layer; il++) {
+ if (filter && !filter(il)) {
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+ continue;
+ }
+
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
  const char * dev_name = "CPU";
 
  lm_ggml_backend_buffer_type_t buft = lm_ggml_backend_cpu_buffer_type();
 
  if (offload) {
- auto * dev = model.dev_layer(i);
+ auto * dev = model.dev_layer(il);
  buft = lm_ggml_backend_dev_buffer_type(dev);
 
  dev_name = lm_ggml_backend_dev_name(dev);
  }
 
- LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name);
+ LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
 
  lm_ggml_context * ctx = ctx_for_buft(buft);
  if (!ctx) {
  throw std::runtime_error("failed to create ggml context for kv cache");
  }
 
- lm_ggml_tensor * k = lm_ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
- lm_ggml_tensor * v = lm_ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
- lm_ggml_format_name(k, "cache_k_l%d", i);
- lm_ggml_format_name(v, "cache_v_l%d", i);
- k_l.push_back(k);
- v_l.push_back(v);
+ lm_ggml_tensor * k;
+ lm_ggml_tensor * v;
+
+ k = lm_ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
+ v = lm_ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+
+ lm_ggml_format_name(k, "cache_k_l%d", il);
+ lm_ggml_format_name(v, "cache_v_l%d", il);
+
+ map_layer_ids[il] = layers.size();
+ layers.push_back({ il, k, v });
  }
 
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -117,8 +119,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  if (!buf) {
  throw std::runtime_error("failed to allocate buffer for kv cache");
  }
- lm_ggml_backend_buffer_clear(buf, 0);
+
  LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, lm_ggml_backend_buffer_name(buf), lm_ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+
+ lm_ggml_backend_buffer_clear(buf, 0);
  bufs.emplace_back(buf);
  }
 
@@ -126,18 +130,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
  const size_t memory_size_k = size_k_bytes();
  const size_t memory_size_v = size_v_bytes();
 
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+ LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
  lm_ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
  lm_ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  }
  }
 
  void llama_kv_cache_unified::clear() {
- for (int32_t i = 0; i < (int32_t) size; ++i) {
+ for (uint32_t i = 0; i < size; ++i) {
  cells[i].pos = -1;
  cells[i].seq_id.clear();
  }
+
  head = 0;
  used = 0;
 
@@ -166,6 +171,7 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
  } else {
  continue;
  }
+
  if (cells[i].is_empty()) {
  // keep count of the number of used cells
  if (cells[i].pos >= 0) {
@@ -262,6 +268,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
  for (uint32_t i = 0; i < size; ++i) {
  if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
  has_shift = true;
+
  cells[i].pos += delta;
  cells[i].delta += delta;
 
@@ -314,53 +321,60 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
  }
  }
 
- llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
- llama_pos result = 0;
+ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+ llama_pos result = std::numeric_limits<llama_pos>::max();
 
  for (uint32_t i = 0; i < size; ++i) {
  if (cells[i].has_seq_id(seq_id)) {
- result = std::max(result, cells[i].pos);
+ result = std::min(result, cells[i].pos);
  }
  }
 
+ if (result == std::numeric_limits<llama_pos>::max()) {
+ result = -1;
+ }
+
  return result;
  }
 
- void llama_kv_cache_unified::restore() {
- if (pending.ranges.empty()) {
- return;
- }
+ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+ llama_pos result = -1;
 
- uint32_t new_head = size;
+ for (uint32_t i = 0; i < size; ++i) {
+ if (cells[i].has_seq_id(seq_id)) {
+ result = std::max(result, cells[i].pos);
+ }
+ }
 
- for (auto & range : pending.ranges) {
- for (uint32_t i = range.c0; i < range.c1; ++i) {
- cells[i].seq_id.clear();
+ return result;
+ }
 
- // keep count of the number of used cells
- if (cells[i].pos >= 0) {
- used--;
- }
+ void llama_kv_cache_unified::restore() {
+ for (const auto & [id, cell] : recovery.cells) {
+ // TODO: move to new `struct kv_cells`
+ const bool is_empty0 = cells[id].is_empty();
+ const bool is_empty1 = cell.is_empty();
 
- cells[i].pos = -1;
+ if (!is_empty0 && is_empty1) {
+ used--;
+ } else if (is_empty0 && !is_empty1) {
+ used++;
  }
 
- new_head = std::min(new_head, range.c0);
+ cells[id] = cell;
  }
 
- if (new_head != size && new_head < head) {
- head = new_head;
- }
+ recovery.clear();
  }
 
  void llama_kv_cache_unified::commit() {
- if (pending.ranges.empty()) {
- LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
- __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+ if (recovery.cells.empty()) {
+ LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
  return;
  }
 
- pending.ranges.clear();
+ recovery.clear();
  }
 
  bool llama_kv_cache_unified::update(llama_context & lctx) {
@@ -429,7 +443,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
  void llama_kv_cache_unified::defrag_sched(float thold) {
  // - do not defrag small contexts (i.e. < 2048 tokens)
  // - count the padding towards the number of used tokens
- const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
+ const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
 
  // queue defragmentation for next llama_kv_cache_update
  if (fragmentation > thold) {
@@ -441,27 +455,26 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
 
  void llama_kv_cache_unified::set_full() {
  n = size;
+
+ // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
+ // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
+ // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
+ // setting it to 0 is the simplest way to achieve that
+ // ref: https://github.com/ggml-org/llama.cpp/issues/13359
+ head = 0;
  }
 
- llama_sbatch llama_kv_cache_unified::sbatch_init(
- const llama_batch & batch,
- bool logits_all) {
+ llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
  return llama_sbatch(batch, hparams.n_embd, true, logits_all);
  }
 
- llama_ubatch llama_kv_cache_unified::ubatch_next(
- llama_sbatch & sbatch,
- uint32_t n_ubatch,
- bool embd_pooled) const {
+ llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
  LM_GGML_UNUSED(embd_pooled);
  return sbatch.split_simple(n_ubatch);
  }
 
- bool llama_kv_cache_unified::find_slot(
- const llama_ubatch & ubatch) {
+ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
  const uint32_t n_tokens = ubatch.n_tokens;
- const uint32_t n_seqs = ubatch.n_seqs;
- const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
  // if we have enough unused cells before the current head ->
  // better to start searching from the beginning of the cache, hoping to fill it
@@ -476,6 +489,29 @@ bool llama_kv_cache_unified::find_slot(
  return false;
  }
 
+ //#define FIND_SLOT_DEBUG 1
+ #if FIND_SLOT_DEBUG
+ LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+
+ // for debugging
+ {
+ std::string ss;
+ if (n_swa > 0) {
+ for (uint32_t i = 0; i < size; ++i) {
+ if (cells[i].pos == -1) {
+ ss += '.';
+ } else {
+ ss += std::to_string(*cells[i].seq_id.begin());
+ }
+ if (i%256 == 255) {
+ ss += '\n';
+ }
+ }
+ }
+ LLAMA_LOG_WARN("\n%s\n", ss.c_str());
+ }
+ #endif
+
  uint32_t n_tested = 0;
 
  while (true) {
@@ -505,60 +541,257 @@ bool llama_kv_cache_unified::find_slot(
  }
  }
 
- for (uint32_t s = 0; s < n_seqs; s++) {
- for (uint32_t i = 0; i < n_seq_tokens; ++i) {
- uint32_t k = s*n_seq_tokens + i;
- cells[head + k].pos = ubatch.pos[k];
+ for (uint32_t i = 0; i < n_tokens; ++i) {
+ // remember the original state
+ if (recovery.cells.find(head + i) == recovery.cells.end()) {
+ recovery.cells[head + i] = cells[head + i];
+ }
 
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
- cells[head + k].seq_id.insert(ubatch.seq_id[s][j]);
- }
+ cells[head + i].pos = ubatch.pos[i];
+
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
+ cells[head + i].seq_id.insert(ubatch.seq_id[i][j]);
  }
  }
 
  used += n_tokens;
 
- pending.ranges.push_back({head, head + n_tokens});
-
  // a heuristic, to avoid attending the full cache if it is not yet utilized
  // after enough generations, the benefit from this heuristic disappears
  // if we start defragmenting the cache, the benefit from this will be more important
- n = std::min(size, std::max(padding, LM_GGML_PAD(cell_max(), padding)));
+ n = std::min(size, std::max(n_pad, LM_GGML_PAD(cell_max(), n_pad)));
+
+ #ifdef FIND_SLOT_DEBUG
+ LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+ #endif
 
- //printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
+ return true;
+ }
 
+ bool llama_kv_cache_unified::get_can_shift() const {
  return true;
  }
 
- int32_t llama_kv_cache_unified::get_n_tokens() const {
- int32_t result = 0;
+ uint32_t llama_kv_cache_unified::get_n() const {
+ return n;
+ }
+
+ uint32_t llama_kv_cache_unified::get_size() const {
+ return size;
+ }
+
+ lm_ggml_tensor * llama_kv_cache_unified::get_k(lm_ggml_context * ctx, int32_t il) const {
+ const int32_t ikv = map_layer_ids.at(il);
+
+ auto * k = layers[ikv].k;
+
+ return lm_ggml_view_3d(ctx, k,
+ hparams.n_embd_head_k, hparams.n_head_kv(il), n,
+ lm_ggml_row_size(k->type, hparams.n_embd_head_k),
+ lm_ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
+ 0);
+ }
 
- for (uint32_t i = 0; i < size; i++) {
- result += cells[i].seq_id.size();
+ lm_ggml_tensor * llama_kv_cache_unified::get_v(lm_ggml_context * ctx, int32_t il) const {
+ const int32_t ikv = map_layer_ids.at(il);
+
+ auto * v = layers[ikv].v;
+
+ if (!v_trans) {
+ // note: v->nb[1] <= v->nb[2]
+ return lm_ggml_view_3d(ctx, v,
+ hparams.n_embd_head_v, hparams.n_head_kv(il), n,
+ lm_ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
+ lm_ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
+ 0);
  }
 
- return result;
+ // note: v->nb[1] > v->nb[2]
+ return lm_ggml_view_3d(ctx, v,
+ n, hparams.n_head_kv(il), hparams.n_embd_head_v,
+ lm_ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
+ lm_ggml_row_size(v->type, v->ne[1]), // v->nb[2]
+ 0);
  }
 
- int32_t llama_kv_cache_unified::get_used_cells() const {
- return used;
+ lm_ggml_tensor * llama_kv_cache_unified::cpy_k(lm_ggml_context * ctx, lm_ggml_tensor * k_cur, int32_t il) const {
+ const int32_t ikv = map_layer_ids.at(il);
+
+ auto * k = layers[ikv].k;
+
+ const int64_t n_tokens = k_cur->ne[2];
+
+ lm_ggml_tensor * k_view = lm_ggml_view_1d(ctx, k,
+ n_tokens*hparams.n_embd_k_gqa(il),
+ lm_ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);
+
+ return lm_ggml_cpy(ctx, k_cur, k_view);
  }
 
- bool llama_kv_cache_unified::get_can_shift() const {
- return can_shift;
+ lm_ggml_tensor * llama_kv_cache_unified::cpy_v(lm_ggml_context * ctx, lm_ggml_tensor * v_cur, int32_t il) const {
+ const int32_t ikv = map_layer_ids.at(il);
+
+ auto * v = layers[ikv].v;
+
+ const int64_t n_tokens = v_cur->ne[2];
+
+ v_cur = lm_ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+
+ lm_ggml_tensor * v_view = nullptr;
+
+ if (!v_trans) {
+ v_view = lm_ggml_view_1d(ctx, v,
+ n_tokens*hparams.n_embd_v_gqa(il),
+ lm_ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
+ } else {
+ // note: the V cache is transposed when not using flash attention
+ v_view = lm_ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
+ (v->ne[1])*lm_ggml_element_size(v),
+ ( head)*lm_ggml_element_size(v));
+
+ v_cur = lm_ggml_transpose(ctx, v_cur);
+ }
+
+ return lm_ggml_cpy(ctx, v_cur, v_view);
  }
 
- llama_pos llama_kv_cache_unified::get_pos_max() const {
- llama_pos pos_max = -1;
- for (const auto & cell : cells) {
- pos_max = std::max(pos_max, cell.pos);
+ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
+ // no pruning is needed when the cache does not use SWA
+ LM_GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
+
+ int n_attended = 0;
+
+ for (uint32_t i = 0; i < size; ++i) {
+ const llama_pos p0 = cells[i].pos;
+
+ if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
+ n_attended++;
+ }
+
+ if (is_masked_swa(p0, pmax)) {
+ if (seq_id < 0) {
+ cells[i].seq_id.clear();
+ } else if (cells[i].has_seq_id(seq_id)) {
+ cells[i].seq_id.erase(seq_id);
+ } else {
+ continue;
+ }
+
+ if (cells[i].is_empty()) {
+ // keep count of the number of used cells
+ if (cells[i].pos >= 0) {
+ used--;
+ }
+
+ cells[i].pos = -1;
+ }
+ }
+ }
+
+ if (n_attended < std::min<int>(n_swa, pmin)) {
+ LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
+ }
+ }
+
+ void llama_kv_cache_unified::set_input_kq_mask(lm_ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+ const int64_t n_tokens = ubatch->n_tokens;
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+ const int64_t n_seqs = ubatch->n_seqs;
+
+ LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(dst->buffer));
+ float * data = (float *) dst->data;
+
+ const int64_t n_kv = n;
+
+ // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+ // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+ // Causal mask:
+ // xxx-------
+ // xxxx------
+ // xxxxx-----
+ // Non-causal mask:
+ // xxxxx-----
+ // xxxxx-----
+ // xxxxx-----
+ // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+ for (int h = 0; h < 1; ++h) {
+ for (int s = 0; s < n_seqs; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+ for (int j = 0; j < n_seq_tokens; ++j) {
+ const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+
+ for (int i = 0; i < n_kv; ++i) {
+ const llama_pos p0 = cells[i].pos;
+
+ bool masked = false;
+
+ // mask the token if not the same sequence
+ masked = masked || (!cells[i].has_seq_id(seq_id));
+
+ // mask future tokens
+ masked = masked || (causal_attn && p0 > p1);
+
+ // apply SWA if any
+ masked = masked || (is_masked_swa(p0, p1));
+
+ float f = 0.0f;
+
+ if (masked) {
+ f = -INFINITY;
+ } else if (hparams.use_alibi) {
+ f = -std::abs(p0 - p1);
+ }
+
+ data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+ }
+ }
+ }
+
+ // mask padded tokens
+ if (data) {
+ for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
+ for (int j = 0; j < n_kv; ++j) {
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+ }
+ }
+ }
  }
+ }
+
+ void llama_kv_cache_unified::set_input_k_shift(lm_ggml_tensor * dst) const {
+ LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(dst->buffer));
+
+ int32_t * data = (int32_t *) dst->data;
+
+ for (uint32_t i = 0; i < size; ++i) {
+ data[i] = cells[i].delta;
+ }
+ }
 
- return pos_max;
+ void llama_kv_cache_unified::set_input_pos_bucket(lm_ggml_tensor * dst, const llama_ubatch * ubatch) const {
+ const int64_t n_tokens = ubatch->n_tokens;
+
+ LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(dst->buffer));
+ LM_GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+ int32_t * data = (int32_t *) dst->data;
+
+ const int64_t n_kv = n;
+
+ for (int h = 0; h < 1; ++h) {
+ for (int j = 0; j < n_tokens; ++j) {
+ for (int i = 0; i < n_kv; ++i) {
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+ }
+ }
+ }
  }
 
  size_t llama_kv_cache_unified::total_size() const {
  size_t size = 0;
+
  for (const auto & buf : bufs) {
  size += lm_ggml_backend_buffer_get_size(buf.get());
  }
@@ -569,8 +802,8 @@ size_t llama_kv_cache_unified::total_size() const {
  size_t llama_kv_cache_unified::size_k_bytes() const {
  size_t size_k_bytes = 0;
 
- for (const auto & k : k_l) {
- size_k_bytes += lm_ggml_nbytes(k);
+ for (const auto & layer : layers) {
+ size_k_bytes += lm_ggml_nbytes(layer.k);
  }
 
  return size_k_bytes;
@@ -579,8 +812,8 @@ size_t llama_kv_cache_unified::size_k_bytes() const {
  size_t llama_kv_cache_unified::size_v_bytes() const {
  size_t size_v_bytes = 0;
 
- for (const auto & v : v_l) {
- size_v_bytes += lm_ggml_nbytes(v);
+ for (const auto & layer : layers) {
+ size_v_bytes += lm_ggml_nbytes(layer.v);
  }
 
  return size_v_bytes;
@@ -644,13 +877,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
  LM_GGML_UNUSED(ubatch);
 
  if (k_shift) {
- assert(lm_ggml_backend_buffer_is_host(k_shift->buffer));
-
- int32_t * data = (int32_t *) k_shift->data;
-
- for (uint32_t i = 0; i < kv_self->size; ++i) {
- data[i] = kv_self->cells[i].delta;
- }
+ kv_self->set_input_k_shift(k_shift);
  }
  }
 
@@ -660,13 +887,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
  lm_ggml_cgraph * gf) const {
  auto res = std::make_unique<llm_graph_result>();
 
- const auto & n_layer = hparams.n_layer;
-
  const auto & n_embd_head_k = hparams.n_embd_head_k;
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
- const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
-
  //LM_GGML_ASSERT(kv_self->size == n_ctx);
 
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
@@ -674,24 +897,22 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
  inp->k_shift = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, cparams.n_ctx);
  lm_ggml_set_input(inp->k_shift);
 
- for (uint32_t il = 0; il < n_layer; ++il) {
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
  const int64_t n_head_kv = hparams.n_head_kv(il);
  const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
- const bool is_swa = hparams.is_swa(il);
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
- // note: the swa rope params could become part of the cparams in the future
- // if we decide to make them configurable, like the non-sliding ones
- const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
- const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
-
- lm_ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+ lm_ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
  lm_ggml_tensor * k =
- lm_ggml_view_3d(ctx, k_l[il],
+ lm_ggml_view_3d(ctx, layer.k,
  n_embd_head_k, n_head_kv, size,
- lm_ggml_row_size(k_l[il]->type, n_embd_head_k),
- lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa),
+ lm_ggml_row_size(layer.k->type, n_embd_head_k),
+ lm_ggml_row_size(layer.k->type, n_embd_k_gqa),
  0);
 
  lm_ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
@@ -796,44 +1017,46 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
  nm++;
  }
 
- for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
  const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
  const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
- lm_ggml_tensor * view_k_src = lm_ggml_view_2d(ctx, k_l[il],
+ lm_ggml_tensor * view_k_src = lm_ggml_view_2d(ctx, layer.k,
  n_embd_k_gqa, nm,
- lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa),
- lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
+ lm_ggml_row_size(layer.k->type, n_embd_k_gqa),
+ lm_ggml_row_size(layer.k->type, n_embd_k_gqa*i));
 
- lm_ggml_tensor * view_k_dst = lm_ggml_view_2d(ctx, k_l[il],
+ lm_ggml_tensor * view_k_dst = lm_ggml_view_2d(ctx, layer.k,
  n_embd_k_gqa, nm,
- lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa),
- lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
+ lm_ggml_row_size(layer.k->type, n_embd_k_gqa),
+ lm_ggml_row_size(layer.k->type, n_embd_k_gqa*id));
 
  lm_ggml_tensor * view_v_src;
  lm_ggml_tensor * view_v_dst;
 
  if (cparams.flash_attn) {
  // NOTE: the V cache is not transposed when using flash attention
- view_v_src = lm_ggml_view_2d(ctx, v_l[il],
+ view_v_src = lm_ggml_view_2d(ctx, layer.v,
  n_embd_v_gqa, nm,
- lm_ggml_row_size(v_l[il]->type, n_embd_v_gqa),
- lm_ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
+ lm_ggml_row_size(layer.v->type, n_embd_v_gqa),
+ lm_ggml_row_size(layer.v->type, n_embd_v_gqa*i));
 
- view_v_dst = lm_ggml_view_2d(ctx, v_l[il],
+ view_v_dst = lm_ggml_view_2d(ctx, layer.v,
  n_embd_v_gqa, nm,
- lm_ggml_row_size(v_l[il]->type, n_embd_v_gqa),
- lm_ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
+ lm_ggml_row_size(layer.v->type, n_embd_v_gqa),
+ lm_ggml_row_size(layer.v->type, n_embd_v_gqa*id));
  } else {
- view_v_src = lm_ggml_view_2d(ctx, v_l[il],
+ view_v_src = lm_ggml_view_2d(ctx, layer.v,
  nm, n_embd_v_gqa,
- lm_ggml_row_size(v_l[il]->type, size),
- lm_ggml_row_size(v_l[il]->type, i));
+ lm_ggml_row_size(layer.v->type, size),
+ lm_ggml_row_size(layer.v->type, i));
 
- view_v_dst = lm_ggml_view_2d(ctx, v_l[il],
+ view_v_dst = lm_ggml_view_2d(ctx, layer.v,
  nm, n_embd_v_gqa,
- lm_ggml_row_size(v_l[il]->type, size),
- lm_ggml_row_size(v_l[il]->type, id));
+ lm_ggml_row_size(layer.v->type, size),
+ lm_ggml_row_size(layer.v->type, id));
  }
 
  lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx, view_k_src, view_k_dst));
@@ -850,7 +1073,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
  }
 
  bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
- const uint32_t n_layer = hparams.n_layer;
+ const uint32_t n_layer = layers.size();
 
  const uint32_t n_kv = cell_max();
  const uint32_t n_used = used;
@@ -998,6 +1221,34 @@ uint32_t llama_kv_cache_unified::cell_max() const {
  return 0;
  }
 
+ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+ if (p0 < 0) {
+ return true;
+ }
+
+ switch (swa_type) {
+ case LLAMA_SWA_TYPE_NONE:
+ {
+ } break;
+ case LLAMA_SWA_TYPE_STANDARD:
+ {
+ if (p1 - p0 >= (int32_t) n_swa) {
+ return true;
+ }
+ } break;
+ case LLAMA_SWA_TYPE_CHUNKED:
+ {
+ const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+ if (p0 < pos_chunk_start) {
+ return true;
+ }
+ } break;
+ }
+
+ return false;
+ }
+
  void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
  std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
  uint32_t cell_count = 0;
@@ -1075,7 +1326,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
 
  void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
  const uint32_t v_trans = this->v_trans ? 1 : 0;
- const uint32_t n_layer = hparams.n_layer;
+ const uint32_t n_layer = layers.size();
 
  io.write(&v_trans, sizeof(v_trans));
  io.write(&n_layer, sizeof(n_layer));
@@ -1084,56 +1335,63 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 
  // Iterate and write all the keys first, each row is a cell
  // Get whole range at a time
- for (uint32_t il = 0; il < n_layer; ++il) {
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
  // Write key type
- const int32_t k_type_i = (int32_t)k_l[il]->type;
+ const int32_t k_type_i = (int32_t)layer.k->type;
  io.write(&k_type_i, sizeof(k_type_i));
 
  // Write row size of key
- const uint64_t k_size_row = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+ const uint64_t k_size_row = lm_ggml_row_size(layer.k->type, n_embd_k_gqa);
  io.write(&k_size_row, sizeof(k_size_row));
 
  // Read each range of cells of k_size length each into tmp_buf and write out
  for (const auto & range : cell_ranges) {
  const size_t range_size = range.second - range.first;
  const size_t buf_size = range_size * k_size_row;
- io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
+ io.write_tensor(layer.k, range.first * k_size_row, buf_size);
  }
  }
 
  if (!v_trans) {
- for (uint32_t il = 0; il < n_layer; ++il) {
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
  // Write value type
- const int32_t v_type_i = (int32_t)v_l[il]->type;
+ const int32_t v_type_i = (int32_t)layer.v->type;
  io.write(&v_type_i, sizeof(v_type_i));
 
  // Write row size of value
- const uint64_t v_size_row = lm_ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+ const uint64_t v_size_row = lm_ggml_row_size(layer.v->type, n_embd_v_gqa);
  io.write(&v_size_row, sizeof(v_size_row));
 
  // Read each range of cells of v_size length each into tmp_buf and write out
  for (const auto & range : cell_ranges) {
  const size_t range_size = range.second - range.first;
  const size_t buf_size = range_size * v_size_row;
- io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
+ io.write_tensor(layer.v, range.first * v_size_row, buf_size);
  }
  }
  } else {
  // When v is transposed, we also need the element size and get the element ranges from each row
  const uint32_t kv_size = size;
- for (uint32_t il = 0; il < n_layer; ++il) {
+
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
  // Write value type
- const int32_t v_type_i = (int32_t)v_l[il]->type;
+ const int32_t v_type_i = (int32_t)layer.v->type;
  io.write(&v_type_i, sizeof(v_type_i));
 
  // Write element size
- const uint32_t v_size_el = lm_ggml_type_size(v_l[il]->type);
+ const uint32_t v_size_el = lm_ggml_type_size(layer.v->type);
  io.write(&v_size_el, sizeof(v_size_el));
 
  // Write GQA embedding size
@@ -1146,7 +1404,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
  const size_t range_size = range.second - range.first;
  const size_t src_offset = (range.first + j * kv_size) * v_size_el;
  const size_t buf_size = range_size * v_size_el;
- io.write_tensor(v_l[il], src_offset, buf_size);
+ io.write_tensor(layer.v, src_offset, buf_size);
  }
  }
  }
@@ -1163,8 +1421,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
  llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 
  batch.n_tokens = cell_count;
- batch.n_seq_tokens = cell_count;
- batch.n_seqs = 1;
 
  for (uint32_t i = 0; i < cell_count; ++i) {
  llama_pos pos;
@@ -1179,13 +1435,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
  }
 
  batch.pos[i] = pos;
+ batch.n_seq_id[i] = 1;
+ batch.seq_id[i] = &dest_seq_id;
  }
- batch.n_seq_id[0] = 1;
- batch.seq_id[0] = &dest_seq_id;
+
  if (!find_slot(batch)) {
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
  return false;
  }
+
  commit();
 
  // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
@@ -1220,11 +1478,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
  llama_seq_id seq_id;
  io.read_to(&seq_id, sizeof(seq_id));
 
- // TODO: llama_kv_cache_unified should have a notion of max sequences
- //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
- if (seq_id < 0) {
- //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
- LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+ if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
  return false;
  }
 
@@ -1242,11 +1497,12 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
  bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
  uint32_t v_trans;
  uint32_t n_layer;
+
  io.read_to(&v_trans, sizeof(v_trans));
  io.read_to(&n_layer, sizeof(n_layer));
 
- if (n_layer != hparams.n_layer) {
- LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+ if (n_layer != layers.size()) {
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
  return false;
  }
  if (cell_count > size) {
@@ -1259,13 +1515,15 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  }
 
  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
- for (uint32_t il = 0; il < n_layer; ++il) {
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
  // Read type of key
  int32_t k_type_i_ref;
  io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
- const int32_t k_type_i = (int32_t) k_l[il]->type;
+ const int32_t k_type_i = (int32_t) layer.k->type;
  if (k_type_i != k_type_i_ref) {
  LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
  return false;
@@ -1274,7 +1532,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  // Read row size of key
  uint64_t k_size_row_ref;
  io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
- const size_t k_size_row = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+ const size_t k_size_row = lm_ggml_row_size(layer.k->type, n_embd_k_gqa);
  if (k_size_row != k_size_row_ref) {
  LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
  return false;
@@ -1282,18 +1540,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
  if (cell_count) {
  // Read and set the keys for the whole cell range
- lm_ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+ lm_ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
  }
  }
 
  if (!this->v_trans) {
- for (uint32_t il = 0; il < n_layer; ++il) {
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
  // Read type of value
  int32_t v_type_i_ref;
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
- const int32_t v_type_i = (int32_t)v_l[il]->type;
+ const int32_t v_type_i = (int32_t)layer.v->type;
  if (v_type_i != v_type_i_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  return false;
@@ -1302,7 +1562,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  // Read row size of value
  uint64_t v_size_row_ref;
  io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
- const size_t v_size_row = lm_ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+ const size_t v_size_row = lm_ggml_row_size(layer.v->type, n_embd_v_gqa);
  if (v_size_row != v_size_row_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
  return false;
@@ -1310,18 +1570,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
  if (cell_count) {
  // Read and set the values for the whole cell range
- lm_ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+ lm_ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
  }
  }
  } else {
  // For each layer, read the values for each cell (transposed)
- for (uint32_t il = 0; il < n_layer; ++il) {
+ for (const auto & layer : layers) {
+ const uint32_t il = layer.il;
+
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
  // Read type of value
  int32_t v_type_i_ref;
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
- const int32_t v_type_i = (int32_t)v_l[il]->type;
+ const int32_t v_type_i = (int32_t)layer.v->type;
  if (v_type_i != v_type_i_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
  return false;
@@ -1330,7 +1592,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  // Read element size of value
  uint32_t v_size_el_ref;
  io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
- const size_t v_size_el = lm_ggml_type_size(v_l[il]->type);
+ const size_t v_size_el = lm_ggml_type_size(layer.v->type);
  if (v_size_el != v_size_el_ref) {
  LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
  return false;
@@ -1348,7 +1610,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  // For each row in the transposed matrix, read the values for the whole cell range
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
  const size_t dst_offset = (head + j * size) * v_size_el;
- lm_ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+ lm_ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
  }
  }
  }
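
A minimal, stand-alone sketch (editorial illustration, not part of the package): the state_read_data() hunks above replace the old index-based loops over the k_l[il]/v_l[il] tensor arrays with iteration over a "layers" container whose entries carry their own layer index plus K/V tensors, so a cache that owns only a subset of layers restores exactly the layers it holds. The struct and names below are hypothetical stand-ins for the real members.

    // Illustrative only: old vs. new per-layer iteration pattern.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct layer_entry {
        uint32_t il;   // model layer index owned by this cache entry
        // lm_ggml_tensor * k;  // per-layer key tensor in the real code
        // lm_ggml_tensor * v;  // per-layer value tensor in the real code
    };

    int main() {
        // a cache that only owns layers 1 and 3 (e.g. one half of an iSWA split)
        std::vector<layer_entry> layers = {{1}, {3}};

        // old pattern (assumes one entry per model layer):
        //   for (uint32_t il = 0; il < n_layer; ++il) { use(k_l[il], v_l[il]); }
        // new pattern (only the layers present in this cache are visited):
        for (const auto & layer : layers) {
            const uint32_t il = layer.il;
            std::printf("restoring K/V state for layer %u\n", il);
        }
        return 0;
    }
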
@@ -1357,6 +1619,193 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
  return true;
  }
 
+ //
+ // llama_kv_cache_unified_iswa
+ //
+
+ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
+ const llama_model & model,
+ lm_ggml_type type_k,
+ lm_ggml_type type_v,
+ bool v_trans,
+ bool offload,
+ bool swa_full,
+ uint32_t kv_size,
+ uint32_t n_seq_max,
+ uint32_t n_batch,
+ uint32_t n_pad) : hparams(model.hparams) {
+ llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+ llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
+
+ const uint32_t size_base = kv_size;
+
+ uint32_t size_swa = std::min(size_base, LM_GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
+
+ // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
+ if (swa_full) {
+ LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+ size_swa = size_base;
+ do_prune = false;
+ }
+
+ LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
+
+ kv_base = std::make_unique<llama_kv_cache_unified>(
+ model, std::move(filter_base), type_k, type_v,
+ v_trans, offload, size_base, n_seq_max, n_pad,
+ 0, LLAMA_SWA_TYPE_NONE);
+
+ LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
+
+ kv_swa = std::make_unique<llama_kv_cache_unified>(
+ model, std::move(filter_swa), type_k, type_v,
+ v_trans, offload, size_swa, n_seq_max, n_pad,
+ hparams.n_swa, hparams.swa_type);
+ }
+
+ void llama_kv_cache_unified_iswa::clear() {
+ kv_base->clear();
+ kv_swa ->clear();
+ }
+
+ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+ bool res = true;
+
+ res = res & kv_base->seq_rm(seq_id, p0, p1);
+ res = res & kv_swa ->seq_rm(seq_id, p0, p1);
+
+ return res;
+ }
+
+ void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+ kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ }
+
+ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
+ kv_base->seq_keep(seq_id);
+ kv_swa ->seq_keep(seq_id);
+ }
+
+ void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+ kv_base->seq_add(seq_id, p0, p1, delta);
+ kv_swa ->seq_add(seq_id, p0, p1, delta);
+ }
+
+ void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+ kv_base->seq_div(seq_id, p0, p1, d);
+ kv_swa ->seq_div(seq_id, p0, p1, d);
+ }
+
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
+ // the base cache is a superset of the SWA cache, so we can just check the SWA cache
+ return kv_swa->seq_pos_min(seq_id);
+ }
+
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
+ return kv_swa->seq_pos_max(seq_id);
+ }
+
+ void llama_kv_cache_unified_iswa::restore() {
+ kv_base->restore();
+ kv_swa ->restore();
+ }
+
+ void llama_kv_cache_unified_iswa::commit() {
+ kv_base->commit();
+ kv_swa ->commit();
+
+ // slide the attention window, forgetting/pruning old tokens that are outside the window
+ if (do_prune) {
+ for (const auto & [seq_id, entry] : pending.pos) {
+ kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
+ }
+
+ }
+
+ pending.clear();
+ }
+
+ bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
+ bool res = true;
+
+ res = res & kv_base->update(lctx);
+ res = res & kv_swa ->update(lctx);
+
+ return res;
+ }
+
+ void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
+ kv_base->defrag_sched(thold);
+ kv_swa ->defrag_sched(thold);
+ }
+
+ void llama_kv_cache_unified_iswa::set_full() {
+ kv_base->set_full();
+ kv_swa ->set_full();
+ }
+
+ llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
+ pending.clear();
+
+ if (do_prune) {
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = batch.seq_id[i][s];
+ const llama_pos pos = batch.pos[i];
+
+ if (pending.pos.find(seq_id) == pending.pos.end()) {
+ pending.pos[seq_id].pmin = pos;
+ pending.pos[seq_id].pmax = pos;
+ } else {
+ pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
+ pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
+ }
+ }
+ }
+ }
+
+ return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+ }
+
+ llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+ LM_GGML_UNUSED(embd_pooled);
+ return sbatch.split_simple(n_ubatch);
+ }
+
+ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
+ bool res = true;
+
+ res = res & kv_base->find_slot(batch);
+ res = res & kv_swa ->find_slot(batch);
+
+ return res;
+ }
+
+ bool llama_kv_cache_unified_iswa::get_can_shift() const {
+ return kv_base->get_size() == kv_swa->get_size();
+ }
+
+ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+ kv_base->state_write(io, seq_id);
+ kv_swa ->state_write(io, seq_id);
+ }
+
+ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+ kv_base->state_read(io, seq_id);
+ kv_swa ->state_read(io, seq_id);
+ }
+
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
+ return kv_base.get();
+ }
+
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
+ return kv_swa.get();
+ }
+
  //
  // llama_kv_cache_recurrent
  //
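
A minimal, stand-alone sketch (editorial illustration, not part of the package): llama_kv_cache_unified_iswa composes two llama_kv_cache_unified instances, a base cache holding the non-SWA layers and a smaller cache holding the sliding-window layers, and forwards every sequence operation to both. The SWA half is sized from the window length, padded and capped by the base size, unless swa_full keeps it at the full size and disables pruning; commit() then prunes SWA cells using the per-sequence pmin/pmax recorded in sbatch_init(). The helper below mirrors the size_swa computation with made-up numbers.

    // Illustrative only: how the SWA cache size falls out of the parameters above.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // round x up to a multiple of n (the role LM_GGML_PAD plays here)
    static uint32_t pad_up(uint32_t x, uint32_t n) {
        return ((x + n - 1) / n) * n;
    }

    static uint32_t swa_cache_size(uint32_t size_base, uint32_t n_swa,
                                   uint32_t n_seq_max, uint32_t n_batch,
                                   uint32_t n_pad, bool swa_full) {
        if (swa_full) {
            return size_base; // full-size SWA cache, pruning disabled
        }
        return std::min(size_base, pad_up(n_swa * n_seq_max + n_batch, n_pad));
    }

    int main() {
        // hypothetical setup: 4096-token window, 1 sequence, 512-token batches
        const uint32_t size_swa = swa_cache_size(/*size_base=*/8192, /*n_swa=*/4096,
                                                 /*n_seq_max=*/1, /*n_batch=*/512,
                                                 /*n_pad=*/256, /*swa_full=*/false);
        std::printf("SWA cache cells: %u (vs. 8192 base cells)\n", size_swa); // 4608
        return 0;
    }
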
@@ -1366,19 +1815,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
  lm_ggml_type type_k,
  lm_ggml_type type_v,
  bool offload,
- uint32_t kv_size) : hparams(model.hparams) {
+ uint32_t kv_size,
+ uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
  const int32_t n_layer = hparams.n_layer;
 
- LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
- __func__, kv_size, lm_ggml_type_name(type_k), lm_ggml_type_name(type_v), n_layer);
+ LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
+ __func__, kv_size, n_seq_max, lm_ggml_type_name(type_k), lm_ggml_type_name(type_v), n_layer);
 
  head = 0;
  size = kv_size;
  used = 0;
 
- this->type_k = type_k;
- this->type_v = type_v;
-
  cells.clear();
  cells.resize(kv_size);
 
@@ -1676,8 +2123,24 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
  }
  }
 
+ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+ llama_pos result = std::numeric_limits<llama_pos>::max();
+
+ for (uint32_t i = 0; i < size; ++i) {
+ if (cells[i].has_seq_id(seq_id)) {
+ result = std::min(result, cells[i].pos);
+ }
+ }
+
+ if (result == std::numeric_limits<llama_pos>::max()) {
+ result = -1;
+ }
+
+ return result;
+ }
+
  llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
- llama_pos result = 0;
+ llama_pos result = -1;
 
  for (uint32_t i = 0; i < size; ++i) {
  if (cells[i].has_seq_id(seq_id)) {
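
A minimal, stand-alone sketch (editorial illustration, not part of the package): the new seq_pos_min() scans all cells for the smallest position owned by the sequence, and both it and seq_pos_max() now report -1 when the sequence owns no cells (seq_pos_max() previously started from 0, which conflated an empty sequence with one at position 0). Simplified types below stand in for the real cell structure.

    // Illustrative only: min-position scan with the -1 "empty sequence" sentinel.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <limits>
    #include <vector>

    using pos_t = int32_t; // plays the role of llama_pos

    struct cell_sketch {
        pos_t pos;
        bool  owned_by_seq; // stand-in for cells[i].has_seq_id(seq_id)
    };

    static pos_t seq_pos_min_sketch(const std::vector<cell_sketch> & cells) {
        pos_t result = std::numeric_limits<pos_t>::max();
        for (const auto & c : cells) {
            if (c.owned_by_seq) {
                result = std::min(result, c.pos);
            }
        }
        return result == std::numeric_limits<pos_t>::max() ? -1 : result;
    }

    int main() {
        std::vector<cell_sketch> cells = {{7, true}, {3, true}, {9, false}};
        std::printf("seq_pos_min = %d\n", seq_pos_min_sketch(cells)); // 3
        std::printf("empty seq   = %d\n", seq_pos_min_sketch({}));    // -1
        return 0;
    }
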
@@ -1700,8 +2163,8 @@ void llama_kv_cache_recurrent::commit() {
  pending.ranges.clear();
  }
 
- bool llama_kv_cache_recurrent::update(llama_context & lctx) {
- LM_GGML_UNUSED(lctx);
+ bool llama_kv_cache_recurrent::update(llama_context & ctx) {
+ LM_GGML_UNUSED(ctx);
  return false;
  }
 
@@ -1712,6 +2175,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
 
  void llama_kv_cache_recurrent::set_full() {
  n = size;
+ head = 0;
  }
 
  llama_sbatch llama_kv_cache_recurrent::sbatch_init(
@@ -1761,7 +2225,7 @@ bool llama_kv_cache_recurrent::find_slot(
  if (seq_id < 0 || (uint32_t) seq_id >= size) {
  // too big seq_id
  // TODO: would it be possible to resize the cache instead?
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
  return false;
  }
  if (j > 0) {
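
A minimal, stand-alone sketch (editorial illustration, not part of the package): the corrected log line above reports the configured n_seq_max instead of the cache size when a sequence id is out of range. The diff itself keeps the comparison against the cache size; the sketch below collapses the bound and the reported limit into a single n_seq_max parameter purely for illustration.

    // Illustrative only: reject out-of-range sequence ids and report the limit.
    #include <cstdint>
    #include <cstdio>

    // returns false (and logs a hint) when seq_id is outside [0, n_seq_max)
    static bool check_seq_id(int32_t seq_id, uint32_t n_seq_max) {
        if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
            std::fprintf(stderr,
                "seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n",
                seq_id, n_seq_max);
            return false;
        }
        return true;
    }

    int main() {
        check_seq_id(0, 4);  // accepted
        check_seq_id(4, 4);  // rejected, logs the hint above
        return 0;
    }
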
@@ -1904,29 +2368,6 @@ bool llama_kv_cache_recurrent::find_slot(
  return n >= n_seqs;
  }
 
- int32_t llama_kv_cache_recurrent::get_n_tokens() const {
- int32_t result = 0;
-
- for (uint32_t i = 0; i < size; i++) {
- result += cells[i].seq_id.size();
- }
-
- return result;
- }
-
- int32_t llama_kv_cache_recurrent::get_used_cells() const {
- return used;
- }
-
- llama_pos llama_kv_cache_recurrent::get_pos_max() const {
- llama_pos pos_max = -1;
- for (const auto & cell : cells) {
- pos_max = std::max(pos_max, cell.pos);
- }
-
- return pos_max;
- }
-
  bool llama_kv_cache_recurrent::get_can_shift() const {
  return false;
  }
@@ -2055,6 +2496,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
  io.read_to(&cell_count, sizeof(cell_count));
 
  bool res = true;
+
  res = res && state_read_meta(io, cell_count, seq_id);
  res = res && state_read_data(io, cell_count);
 
@@ -2383,104 +2825,3 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
  return true;
  }
-
- //
- // kv cache view
- //
-
- llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
- llama_kv_cache_view result = {
- /*.n_cells = */ 0,
- /*.n_seq_max = */ n_seq_max,
- /*.token_count = */ 0,
- /*.used_cells = */ kv.get_used_cells(),
- /*.max_contiguous = */ 0,
- /*.max_contiguous_idx = */ -1,
- /*.cells = */ nullptr,
- /*.cells_sequences = */ nullptr,
- };
-
- return result;
- }
-
- void llama_kv_cache_view_free(llama_kv_cache_view * view) {
- if (view->cells != nullptr) {
- free(view->cells);
- view->cells = nullptr;
- }
- if (view->cells_sequences != nullptr) {
- free(view->cells_sequences);
- view->cells_sequences = nullptr;
- }
- }
-
- void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) {
- // TODO: rework this in the future, for now quick hack
- const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
- if (kvu == nullptr) {
- LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
- return;
- }
-
- if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
- view->n_cells = int32_t(kvu->size);
- void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells);
- LM_GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
- view->cells = (llama_kv_cache_view_cell *)p;
- p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
- LM_GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
- view->cells_sequences = (llama_seq_id *)p;
- }
-
- const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
- llama_kv_cache_view_cell * c_curr = view->cells;
- llama_seq_id * cs_curr = view->cells_sequences;
- int32_t used_cells = 0;
- int32_t token_count = 0;
- int32_t curr_contig_idx = -1;
- uint32_t max_contig = 0;
- int32_t max_contig_idx = -1;
-
- for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
- const size_t curr_size = kv_cells[i].seq_id.size();
- token_count += curr_size;
- c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
-
- if (curr_size > 0) {
- if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
- max_contig = i - curr_contig_idx;
- max_contig_idx = curr_contig_idx;
- }
- curr_contig_idx = -1;
- } else if (curr_contig_idx < 0) {
- curr_contig_idx = i;
- }
-
- int seq_idx = 0;
- for (const llama_seq_id it : kv_cells[i].seq_id) {
- if (seq_idx >= view->n_seq_max) {
- break;
- }
- cs_curr[seq_idx] = it;
- seq_idx++;
- }
- if (seq_idx != 0) {
- used_cells++;
- }
- for (; seq_idx < view->n_seq_max; seq_idx++) {
- cs_curr[seq_idx] = -1;
- }
- }
- if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
- max_contig_idx = curr_contig_idx;
- max_contig = kv_cells.size() - curr_contig_idx;
- }
- view->max_contiguous = max_contig;
- view->max_contiguous_idx = max_contig_idx;
- view->token_count = token_count;
- view->used_cells = used_cells;
- if (uint32_t(used_cells) != kvu->used) {
- LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
- __func__, kvu->used, used_cells);
- }
- }