@novastera-oss/llamarn 0.2.7 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +56 -22
  11. package/cpp/build-info.cpp +2 -2
  12. package/cpp/llama.cpp/CMakeLists.txt +1 -1
  13. package/cpp/llama.cpp/common/arg.cpp +7 -0
  14. package/cpp/llama.cpp/common/common.cpp +3 -0
  15. package/cpp/llama.cpp/common/common.h +1 -0
  16. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  17. package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
  18. package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
  19. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  20. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  21. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
  22. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  23. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
  24. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  25. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  26. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  27. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  28. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  30. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  32. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  33. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  34. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  35. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  62. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
  64. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
  65. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  66. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
  67. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  68. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  69. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  70. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  71. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  72. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  73. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  74. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  76. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  77. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
  78. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  79. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  80. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  81. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  82. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  83. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  89. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  90. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  92. package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
  93. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  94. package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
  95. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
  96. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
  97. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  98. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  99. package/cpp/llama.cpp/include/llama.h +8 -3
  100. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  101. package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
  102. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  103. package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
  104. package/cpp/llama.cpp/src/llama-batch.h +98 -70
  105. package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
  106. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  107. package/cpp/llama.cpp/src/llama-context.h +13 -13
  108. package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
  109. package/cpp/llama.cpp/src/llama-graph.h +44 -32
  110. package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
  111. package/cpp/llama.cpp/src/llama-hparams.h +8 -0
  112. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
  113. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
  114. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
  115. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
  116. package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
  117. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
  118. package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
  119. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
  120. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  121. package/cpp/llama.cpp/src/llama-memory.h +18 -22
  122. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  123. package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
  124. package/cpp/llama.cpp/src/llama-model.h +22 -0
  125. package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
  126. package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
  127. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  128. package/cpp/rn-utils.h +3 -0
  129. package/ios/include/common.h +1 -0
  130. package/ios/include/llama.h +8 -3
  131. package/ios/libs/llama.xcframework/Info.plist +19 -19
  132. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  133. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  134. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  135. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  136. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
  137. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  138. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  139. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  140. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  141. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  142. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  143. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  144. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  145. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  146. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  147. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
  148. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  149. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  150. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
  151. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  152. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  153. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
  154. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  155. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  160. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  161. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  162. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  163. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  164. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
  165. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  168. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  173. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
  175. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
  178. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  183. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  184. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  186. package/package.json +1 -1
@@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
33
33
 
34
34
  GGML_ASSERT(kv_size % n_pad == 0);
35
35
 
36
+ // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
37
+ auto n_layer_cache = hparams.n_layer;
38
+ if (model.arch == LLM_ARCH_GEMMA3N) {
39
+ n_layer_cache = 20;
40
+ }
41
+
36
42
  // create a context for each buffer type
37
43
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
38
44
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
39
45
  auto it = ctx_map.find(buft);
40
46
  if (it == ctx_map.end()) {
41
47
  ggml_init_params params = {
42
- /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
48
+ /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
43
49
  /*.mem_buffer =*/ NULL,
44
50
  /*.no_alloc =*/ true,
45
51
  };
@@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
62
68
 
63
69
  cells.resize(kv_size);
64
70
 
65
- for (uint32_t il = 0; il < hparams.n_layer; il++) {
71
+ for (uint32_t il = 0; il < n_layer_cache; il++) {
66
72
  if (filter && !filter(il)) {
67
73
  LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
68
74
  continue;
@@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
102
108
  layers.push_back({ il, k, v });
103
109
  }
104
110
 
111
+ // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
112
+ if (model.arch == LLM_ARCH_GEMMA3N) {
113
+ LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
114
+
115
+ for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
116
+ if (filter && !filter(il)) {
117
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
118
+ continue;
119
+ }
120
+
121
+ const bool is_swa = hparams.is_swa(il);
122
+ const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
123
+
124
+ GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
125
+ map_layer_ids[il] = map_layer_ids[il_reuse];
126
+
127
+ LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
128
+ }
129
+ }
130
+
105
131
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
106
132
  for (auto it : ctx_map) {
107
133
  auto * buft = it.first;
@@ -307,18 +333,24 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
307
333
  return cells.seq_pos_max(seq_id);
308
334
  }
309
335
 
310
- llama_memory_state_ptr llama_kv_cache_unified::init_batch(
311
- const llama_batch & batch,
336
+ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
337
+ llama_batch_allocr & balloc,
312
338
  uint32_t n_ubatch,
313
339
  bool embd_all) {
314
340
  GGML_UNUSED(embd_all);
315
341
 
316
342
  do {
317
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
343
+ balloc.split_reset();
318
344
 
319
345
  std::vector<llama_ubatch> ubatches;
320
- while (sbatch.n_tokens > 0) {
321
- ubatches.push_back(sbatch.split_simple(n_ubatch));
346
+ while (true) {
347
+ auto ubatch = balloc.split_simple(n_ubatch);
348
+
349
+ if (ubatch.n_tokens == 0) {
350
+ break;
351
+ }
352
+
353
+ ubatches.push_back(std::move(ubatch)); // NOLINT
322
354
  }
323
355
 
324
356
  auto heads = prepare(ubatches);
@@ -326,18 +358,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
326
358
  break;
327
359
  }
328
360
 
329
- return std::make_unique<llama_kv_cache_unified_state>(
330
- this, std::move(sbatch), std::move(heads), std::move(ubatches));
361
+ return std::make_unique<llama_kv_cache_unified_context>(
362
+ this, std::move(heads), std::move(ubatches));
331
363
  } while (false);
332
364
 
333
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
365
+ return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
334
366
  }
335
367
 
336
- llama_memory_state_ptr llama_kv_cache_unified::init_full() {
337
- return std::make_unique<llama_kv_cache_unified_state>(this);
368
+ llama_memory_context_ptr llama_kv_cache_unified::init_full() {
369
+ return std::make_unique<llama_kv_cache_unified_context>(this);
338
370
  }
339
371
 
340
- llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
372
+ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
341
373
  bool do_shift = get_has_shift();
342
374
 
343
375
  defrag_info dinfo;
@@ -367,7 +399,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
367
399
  }
368
400
  }
369
401
 
370
- return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
402
+ return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
371
403
  }
372
404
 
373
405
  llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -644,12 +676,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
644
676
  }
645
677
 
646
678
  void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
647
- if (debug > 0) {
648
- LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
649
- LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
650
- LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
651
- }
652
-
653
679
  // keep track of the max sequence position that we would overwrite with this ubatch
654
680
  // for non-SWA cache, this would be always empty
655
681
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -657,27 +683,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
657
683
  seq_pos_max_rm[s] = -1;
658
684
  }
659
685
 
660
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
661
- for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
662
- const uint32_t idx = s*ubatch.n_seq_tokens + j;
663
-
664
- if (!cells.is_empty(head_cur + idx)) {
665
- assert(cells.seq_count(head_cur + idx) == 1);
686
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
687
+ if (!cells.is_empty(head_cur + i)) {
688
+ assert(cells.seq_count(head_cur + i) == 1);
666
689
 
667
- const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
668
- const llama_pos pos = cells.pos_get(head_cur + idx);
690
+ const llama_seq_id seq_id = cells.seq_get(head_cur + i);
691
+ const llama_pos pos = cells.pos_get(head_cur + i);
669
692
 
670
- seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
693
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
671
694
 
672
- cells.rm(head_cur + idx);
673
- }
695
+ cells.rm(head_cur + i);
696
+ }
674
697
 
675
- cells.pos_set(head_cur + idx, ubatch.pos[idx]);
698
+ cells.pos_set(head_cur + i, ubatch.pos[i]);
676
699
 
677
- // TODO: fix indexing [UBATCH_IDX]
678
- for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
679
- cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
680
- }
700
+ for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
701
+ cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
681
702
  }
682
703
  }
683
704
 
@@ -696,6 +717,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
696
717
  seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
697
718
  }
698
719
  }
720
+
699
721
  // move the head at the end of the slot
700
722
  head = head_cur + ubatch.n_tokens;
701
723
  }
@@ -792,9 +814,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
792
814
  }
793
815
 
794
816
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
795
- const uint32_t n_tokens = ubatch->n_tokens;
796
- const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
797
- const uint32_t n_seqs = ubatch->n_seqs;
817
+ const uint32_t n_tokens = ubatch->n_tokens;
798
818
 
799
819
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
800
820
  float * data = (float *) dst->data;
@@ -814,52 +834,48 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
814
834
  // xxxxx-----
815
835
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
816
836
  for (uint32_t h = 0; h < 1; ++h) {
817
- for (uint32_t s = 0; s < n_seqs; ++s) {
818
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
819
-
820
- for (uint32_t j = 0; j < n_seq_tokens; ++j) {
821
- const uint32_t idx = s*n_seq_tokens + j;
837
+ for (uint32_t i = 0; i < n_tokens; ++i) {
838
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
822
839
 
823
- const llama_pos p1 = ubatch->pos[idx];
840
+ const llama_pos p1 = ubatch->pos[i];
824
841
 
825
- for (uint32_t i = 0; i < n_kv; ++i) {
826
- float f = 0.0f;
842
+ for (uint32_t j = 0; j < n_kv; ++j) {
843
+ float f = 0.0f;
827
844
 
828
- bool masked = false;
829
-
830
- if (cells.is_empty(i)) {
831
- masked = true;
832
- } else {
833
- const llama_pos p0 = cells.pos_get(i);
845
+ bool masked = false;
834
846
 
835
- // mask the token if not the same sequence
836
- masked = masked || (!cells.seq_has(i, seq_id));
847
+ if (cells.is_empty(j)) {
848
+ masked = true;
849
+ } else {
850
+ const llama_pos p0 = cells.pos_get(j);
837
851
 
838
- // mask future tokens
839
- masked = masked || (causal_attn && p0 > p1);
852
+ // mask the token if not the same sequence
853
+ masked = masked || (!cells.seq_has(j, seq_id));
840
854
 
841
- // apply SWA if any
842
- masked = masked || (is_masked_swa(p0, p1));
855
+ // mask future tokens
856
+ masked = masked || (causal_attn && p0 > p1);
843
857
 
844
- if (!masked && hparams.use_alibi) {
845
- f = -std::abs(p0 - p1);
846
- }
847
- }
858
+ // apply SWA if any
859
+ masked = masked || (is_masked_swa(p0, p1));
848
860
 
849
- if (masked) {
850
- f = -INFINITY;
861
+ if (!masked && hparams.use_alibi) {
862
+ f = -std::abs(p0 - p1);
851
863
  }
864
+ }
852
865
 
853
- data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
866
+ if (masked) {
867
+ f = -INFINITY;
854
868
  }
869
+
870
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
855
871
  }
856
872
  }
857
873
 
858
874
  // mask padded tokens
859
875
  if (data) {
860
- for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
861
- for (uint32_t i = 0; i < n_kv; ++i) {
862
- data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
876
+ for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
877
+ for (uint32_t j = 0; j < n_kv; ++j) {
878
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
863
879
  }
864
880
  }
865
881
  }
@@ -887,12 +903,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
887
903
  const int32_t n_kv = dst->ne[0];
888
904
 
889
905
  for (int h = 0; h < 1; ++h) {
890
- for (int j = 0; j < n_tokens; ++j) {
891
- for (int i = 0; i < n_kv; ++i) {
906
+ for (int i = 0; i < n_tokens; ++i) {
907
+ for (int j = 0; j < n_kv; ++j) {
892
908
  // the position when the cells is empty is irrelevant - it will be masked out later in the attention
893
- const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
909
+ const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
894
910
 
895
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
911
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
896
912
  }
897
913
  }
898
914
  }
@@ -1509,12 +1525,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1509
1525
 
1510
1526
  seq_rm(dest_seq_id, -1, -1);
1511
1527
 
1512
- llama_sbatch sbatch;
1513
- llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1528
+ llama_batch_allocr balloc(hparams.n_pos_per_embd());
1514
1529
 
1515
- ubatch.n_tokens = cell_count;
1516
- ubatch.n_seq_tokens = cell_count;
1517
- ubatch.n_seqs = 1;
1530
+ llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
1518
1531
 
1519
1532
  for (uint32_t i = 0; i < cell_count; ++i) {
1520
1533
  llama_pos pos;
@@ -1723,18 +1736,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1723
1736
  }
1724
1737
 
1725
1738
  //
1726
- // llama_kv_cache_unified_state
1739
+ // llama_kv_cache_unified_context
1727
1740
  //
1728
1741
 
1729
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
1742
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
1730
1743
 
1731
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1744
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1732
1745
  llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1733
1746
  n_kv = kv->get_size();
1734
1747
  head = 0;
1735
1748
  }
1736
1749
 
1737
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1750
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1738
1751
  llama_kv_cache_unified * kv,
1739
1752
  llama_context * lctx,
1740
1753
  bool do_shift,
@@ -1744,16 +1757,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1744
1757
  }
1745
1758
  }
1746
1759
 
1747
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1760
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1748
1761
  llama_kv_cache_unified * kv,
1749
- llama_sbatch sbatch,
1750
1762
  llama_kv_cache_unified::ubatch_heads heads,
1751
- std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1763
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1752
1764
  }
1753
1765
 
1754
- llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
1766
+ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
1755
1767
 
1756
- bool llama_kv_cache_unified_state::next() {
1768
+ bool llama_kv_cache_unified_context::next() {
1757
1769
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1758
1770
 
1759
1771
  if (++i_next >= ubatches.size()) {
@@ -1763,7 +1775,7 @@ bool llama_kv_cache_unified_state::next() {
1763
1775
  return true;
1764
1776
  }
1765
1777
 
1766
- bool llama_kv_cache_unified_state::apply() {
1778
+ bool llama_kv_cache_unified_context::apply() {
1767
1779
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1768
1780
 
1769
1781
  // no ubatches -> this is a KV cache update
@@ -1781,51 +1793,45 @@ bool llama_kv_cache_unified_state::apply() {
1781
1793
  return true;
1782
1794
  }
1783
1795
 
1784
- std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
1785
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1786
-
1787
- return sbatch.out_ids;
1788
- }
1789
-
1790
- llama_memory_status llama_kv_cache_unified_state::get_status() const {
1796
+ llama_memory_status llama_kv_cache_unified_context::get_status() const {
1791
1797
  return status;
1792
1798
  }
1793
1799
 
1794
- const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
1800
+ const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
1795
1801
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1796
1802
 
1797
1803
  return ubatches[i_next];
1798
1804
  }
1799
1805
 
1800
- uint32_t llama_kv_cache_unified_state::get_n_kv() const {
1806
+ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
1801
1807
  return n_kv;
1802
1808
  }
1803
1809
 
1804
- ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
1810
+ ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
1805
1811
  return kv->get_k(ctx, il, n_kv);
1806
1812
  }
1807
1813
 
1808
- ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
1814
+ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
1809
1815
  return kv->get_v(ctx, il, n_kv);
1810
1816
  }
1811
1817
 
1812
- ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1818
+ ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1813
1819
  return kv->cpy_k(ctx, k_cur, il, head);
1814
1820
  }
1815
1821
 
1816
- ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1822
+ ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1817
1823
  return kv->cpy_v(ctx, v_cur, il, head);
1818
1824
  }
1819
1825
 
1820
- void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
1826
+ void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
1821
1827
  kv->set_input_k_shift(dst);
1822
1828
  }
1823
1829
 
1824
- void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1830
+ void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1825
1831
  kv->set_input_kq_mask(dst, ubatch, causal_attn);
1826
1832
  }
1827
1833
 
1828
- void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1834
+ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1829
1835
  kv->set_input_pos_bucket(dst, ubatch);
1830
1836
  }
1831
1837
 
@@ -56,14 +56,14 @@ public:
56
56
  // llama_memory_i
57
57
  //
58
58
 
59
- llama_memory_state_ptr init_batch(
60
- const llama_batch & batch,
59
+ llama_memory_context_ptr init_batch(
60
+ llama_batch_allocr & balloc,
61
61
  uint32_t n_ubatch,
62
62
  bool embd_all) override;
63
63
 
64
- llama_memory_state_ptr init_full() override;
64
+ llama_memory_context_ptr init_full() override;
65
65
 
66
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
66
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
67
67
 
68
68
  bool get_can_shift() const override;
69
69
 
@@ -208,49 +208,46 @@ private:
208
208
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
209
209
  };
210
210
 
211
- class llama_kv_cache_unified_state : public llama_memory_state_i {
211
+ class llama_kv_cache_unified_context : public llama_memory_context_i {
212
212
  public:
213
213
  // some shorthands
214
214
  using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
215
215
  using defrag_info = llama_kv_cache_unified::defrag_info;
216
216
 
217
217
  // used for errors
218
- llama_kv_cache_unified_state(llama_memory_status status);
218
+ llama_kv_cache_unified_context(llama_memory_status status);
219
219
 
220
- // used to create a full-cache state
221
- llama_kv_cache_unified_state(
220
+ // used to create a full-cache context
221
+ llama_kv_cache_unified_context(
222
222
  llama_kv_cache_unified * kv);
223
223
 
224
- // used to create an update state
225
- llama_kv_cache_unified_state(
224
+ // used to create an update context
225
+ llama_kv_cache_unified_context(
226
226
  llama_kv_cache_unified * kv,
227
227
  llama_context * lctx,
228
228
  bool do_shift,
229
229
  defrag_info dinfo);
230
230
 
231
- // used to create a decode state from a batch
232
- llama_kv_cache_unified_state(
231
+ // used to create a batch processing context from a batch
232
+ llama_kv_cache_unified_context(
233
233
  llama_kv_cache_unified * kv,
234
- llama_sbatch sbatch,
235
234
  ubatch_heads heads,
236
235
  std::vector<llama_ubatch> ubatches);
237
236
 
238
- virtual ~llama_kv_cache_unified_state();
237
+ virtual ~llama_kv_cache_unified_context();
239
238
 
240
239
  //
241
- // llama_memory_state_i
240
+ // llama_memory_context_i
242
241
  //
243
242
 
244
243
  bool next() override;
245
244
  bool apply() override;
246
245
 
247
- std::vector<int64_t> & out_ids() override;
248
-
249
246
  llama_memory_status get_status() const override;
250
247
  const llama_ubatch & get_ubatch() const override;
251
248
 
252
249
  //
253
- // llama_kv_cache_unified_state specific API
250
+ // llama_kv_cache_unified_context specific API
254
251
  //
255
252
 
256
253
  uint32_t get_n_kv() const;
@@ -275,7 +272,7 @@ private:
275
272
  llama_context * lctx;
276
273
 
277
274
  //
278
- // update state
275
+ // update context
279
276
  //
280
277
 
281
278
  bool do_shift = false;
@@ -283,11 +280,9 @@ private:
283
280
  defrag_info dinfo;
284
281
 
285
282
  //
286
- // batch processing state
283
+ // batch processing context
287
284
  //
288
285
 
289
- llama_sbatch sbatch;
290
-
291
286
  // the index of the next ubatch to process
292
287
  size_t i_next = 0;
293
288
 
@@ -7,6 +7,7 @@
7
7
  #include <cassert>
8
8
  #include <vector>
9
9
  #include <set>
10
+ #include <map>
10
11
 
11
12
  // meta information about KV cells that can be part of multiple sequences at the same time
12
13
  // TODO: add unit tests
@@ -164,7 +165,7 @@ public:
164
165
  assert(seq_id >= 0);
165
166
 
166
167
  seq[i].reset(seq_id);
167
- seq_pos[seq_id].erase(pos[i]);
168
+ seq_pos_dec(seq_id, pos[i]);
168
169
 
169
170
  if (seq[i].none()) {
170
171
  pos[i] = -1;
@@ -187,7 +188,7 @@ public:
187
188
  seq[i].reset();
188
189
 
189
190
  seq[i].set(seq_id);
190
- seq_pos[seq_id].insert(pos[i]);
191
+ seq_pos_inc(seq_id, pos[i]);
191
192
 
192
193
  return false;
193
194
  }
@@ -232,7 +233,7 @@ public:
232
233
  assert(!seq[i].test(seq_id));
233
234
 
234
235
  seq[i].set(seq_id);
235
- seq_pos[seq_id].insert(pos[i]);
236
+ seq_pos_inc(seq_id, pos[i]);
236
237
  }
237
238
 
238
239
  // return the sequence id of this cell
@@ -259,7 +260,9 @@ public:
259
260
  return -1;
260
261
  }
261
262
 
262
- return *seq_pos[seq_id].begin();
263
+ assert(seq_pos[seq_id].begin()->second > 0);
264
+
265
+ return seq_pos[seq_id].begin()->first;
263
266
  }
264
267
 
265
268
  // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ public:
272
275
  return -1;
273
276
  }
274
277
 
275
- return *seq_pos[seq_id].rbegin();
278
+ assert(seq_pos[seq_id].rbegin()->second > 0);
279
+
280
+ return seq_pos[seq_id].rbegin()->first;
276
281
  }
277
282
 
278
283
  // note: call only if the cell is not empty
@@ -384,22 +389,41 @@ private:
384
389
  //
385
390
  std::vector<llama_pos> shift;
386
391
 
387
- using bits_t = std::bitset<LLAMA_MAX_SEQ>;
392
+ using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
388
393
 
389
394
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390
- std::vector<bits_t> seq;
395
+ std::vector<seq_set_t> seq;
391
396
 
392
- // the set seq_pos[s] tells us which positions are currently present for sequence s
397
+ // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
398
+ // if the position p is not present, seq_pos[s][p] is not set
393
399
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394
- std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
400
+ //
401
+ // note that we cannot use an std::set because in some cases a position can occur more than once for the same seq:
402
+ // - during performing a cache reuse via (rm + add)
403
+ // - some vision models have input embeddings with repeating positions
404
+ //
405
+ std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
395
406
 
396
407
  // helper functions for updating `seq_pos`, one cell at a time:
397
408
 
409
+ void seq_pos_dec(llama_seq_id s, llama_pos p) {
410
+ auto it = seq_pos[s].find(p);
411
+ assert(it != seq_pos[s].end());
412
+
413
+ if (--it->second == 0) {
414
+ seq_pos[s].erase(it);
415
+ }
416
+ }
417
+
418
+ void seq_pos_inc(llama_seq_id s, llama_pos p) {
419
+ seq_pos[s][p]++;
420
+ }
421
+
398
422
  // remove cell i
399
423
  void seq_pos_rm(uint32_t i) {
400
424
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
401
425
  if (seq[i].test(s)) {
402
- seq_pos[s].erase(pos[i]);
426
+ seq_pos_dec(s, pos[i]);
403
427
  }
404
428
  }
405
429
  }
@@ -408,7 +432,7 @@ private:
408
432
  void seq_pos_add(uint32_t i) {
409
433
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
410
434
  if (seq[i].test(s)) {
411
- seq_pos[s].insert(pos[i]);
435
+ seq_pos_inc(s, pos[i]);
412
436
  }
413
437
  }
414
438
  }