@novastera-oss/llamarn 0.2.7 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +56 -22
  11. package/cpp/build-info.cpp +2 -2
  12. package/cpp/llama.cpp/CMakeLists.txt +1 -1
  13. package/cpp/llama.cpp/common/arg.cpp +7 -0
  14. package/cpp/llama.cpp/common/common.cpp +3 -0
  15. package/cpp/llama.cpp/common/common.h +1 -0
  16. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  17. package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
  18. package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
  19. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  20. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  21. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
  22. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  23. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
  24. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  25. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  26. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  27. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  28. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  30. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  32. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  33. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  34. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  35. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  62. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
  64. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
  65. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  66. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
  67. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  68. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  69. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  70. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  71. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  72. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  73. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  74. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  76. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  77. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
  78. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  79. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  80. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  81. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  82. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  83. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  89. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  90. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  92. package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
  93. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  94. package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
  95. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
  96. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
  97. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  98. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  99. package/cpp/llama.cpp/include/llama.h +8 -3
  100. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  101. package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
  102. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  103. package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
  104. package/cpp/llama.cpp/src/llama-batch.h +98 -70
  105. package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
  106. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  107. package/cpp/llama.cpp/src/llama-context.h +13 -13
  108. package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
  109. package/cpp/llama.cpp/src/llama-graph.h +44 -32
  110. package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
  111. package/cpp/llama.cpp/src/llama-hparams.h +8 -0
  112. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
  113. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
  114. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
  115. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
  116. package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
  117. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
  118. package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
  119. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
  120. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  121. package/cpp/llama.cpp/src/llama-memory.h +18 -22
  122. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  123. package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
  124. package/cpp/llama.cpp/src/llama-model.h +22 -0
  125. package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
  126. package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
  127. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  128. package/cpp/rn-utils.h +3 -0
  129. package/ios/include/common.h +1 -0
  130. package/ios/include/llama.h +8 -3
  131. package/ios/libs/llama.xcframework/Info.plist +19 -19
  132. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  133. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  134. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  135. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  136. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
  137. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  138. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  139. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  140. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  141. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  142. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  143. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  144. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  145. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  146. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  147. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
  148. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  149. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  150. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
  151. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  152. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  153. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
  154. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  155. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  160. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  161. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  162. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  163. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  164. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
  165. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  168. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  173. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
  175. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
  178. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  183. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  184. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  186. package/package.json +1 -1
@@ -32,7 +32,7 @@ llama_memory_hybrid::llama_memory_hybrid(
32
32
  mem_attn(new llama_kv_cache_unified(
33
33
  model,
34
34
  filter_attn == nullptr ?
35
- [&](int32_t il) { return !model.hparams.is_recurrent(il); }
35
+ [&](int32_t il) { return !hparams.is_recurrent(il); }
36
36
  : filter_attn,
37
37
  type_k,
38
38
  type_v,
@@ -47,7 +47,7 @@ llama_memory_hybrid::llama_memory_hybrid(
47
47
  mem_recr(new llama_memory_recurrent(
48
48
  model,
49
49
  filter_recr == nullptr ?
50
- [&](int32_t il) { return model.hparams.is_recurrent(il); }
50
+ [&](int32_t il) { return hparams.is_recurrent(il); }
51
51
  : filter_recr,
52
52
  type_r,
53
53
  type_s,
@@ -56,50 +56,57 @@ llama_memory_hybrid::llama_memory_hybrid(
56
56
  n_seq_max
57
57
  )) {}
58
58
 
59
- llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
59
+ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
60
+ do {
61
+ balloc.split_reset();
60
62
 
61
- // since this includes a recurrent cache, we cannot use split_simple
62
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
63
+ // follow the recurrent pattern for creating the ubatch splits
64
+ std::vector<llama_ubatch> ubatches;
63
65
 
64
- // follow the recurrent pattern for creating the ubatch splits
65
- std::vector<llama_ubatch> ubatches;
66
- while (sbatch.n_tokens > 0) {
67
- llama_ubatch ubatch;
66
+ while (true) {
67
+ llama_ubatch ubatch;
68
68
 
69
- if (embd_pooled) {
70
- // Pooled embeddings cannot be split across ubatches (yet)
71
- ubatch = sbatch.split_seq(n_ubatch);
72
- } else {
73
- ubatch = sbatch.split_equal(n_ubatch);
69
+ if (embd_all) {
70
+ // if all tokens are output, split by sequence
71
+ ubatch = balloc.split_seq(n_ubatch);
72
+ } else {
73
+ ubatch = balloc.split_equal(n_ubatch);
74
+ }
75
+
76
+ if (ubatch.n_tokens == 0) {
77
+ break;
78
+ }
79
+
80
+ ubatches.push_back(std::move(ubatch)); // NOLINT
74
81
  }
75
82
 
76
- ubatches.push_back(ubatch);
77
- }
83
+ // prepare the recurrent batches first
84
+ if (!mem_recr->prepare(ubatches)) {
85
+ // TODO: will the recurrent cache be in an undefined context at this point?
86
+ LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
87
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
88
+ }
78
89
 
79
- // prepare the recurrent batches first
80
- if (!mem_recr->prepare(ubatches)) {
81
- // TODO: will the recurrent cache be in an undefined state at this point?
82
- LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
83
- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
84
- }
90
+ // prepare the attention cache
91
+ auto heads_attn = mem_attn->prepare(ubatches);
92
+ if (heads_attn.empty()) {
93
+ LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
94
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
95
+ }
85
96
 
86
- // prepare the attention cache
87
- auto heads_attn = mem_attn->prepare(ubatches);
88
- if (heads_attn.empty()) {
89
- LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
90
- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
91
- }
97
+ return std::make_unique<llama_memory_hybrid_context>(
98
+ this, std::move(heads_attn), std::move(ubatches));
99
+ } while(false);
92
100
 
93
- return std::make_unique<llama_memory_hybrid_state>(
94
- this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
101
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
95
102
  }
96
103
 
97
- llama_memory_state_ptr llama_memory_hybrid::init_full() {
98
- return std::make_unique<llama_memory_hybrid_state>(this);
104
+ llama_memory_context_ptr llama_memory_hybrid::init_full() {
105
+ return std::make_unique<llama_memory_hybrid_context>(this);
99
106
  }
100
107
 
101
- llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
102
- return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
108
+ llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
109
+ return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
103
110
  }
104
111
 
105
112
  bool llama_memory_hybrid::get_can_shift() const {
@@ -169,41 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
169
176
  return mem_recr.get();
170
177
  }
171
178
 
172
- llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
179
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
173
180
 
174
- llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
175
- state_attn(mem->get_mem_attn()->init_full()),
176
- state_recr(mem->get_mem_recr()->init_full()),
177
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
181
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
182
+ ctx_attn(mem->get_mem_attn()->init_full()),
183
+ ctx_recr(mem->get_mem_recr()->init_full()),
184
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
178
185
  }
179
186
 
180
- llama_memory_hybrid_state::llama_memory_hybrid_state(
187
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
181
188
  llama_memory_hybrid * mem,
182
189
  llama_context * lctx,
183
190
  bool optimize) :
184
- state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
185
- state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
186
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
191
+ ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
192
+ ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
193
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
187
194
  }
188
195
 
189
- llama_memory_hybrid_state::llama_memory_hybrid_state(
196
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
190
197
  llama_memory_hybrid * mem,
191
- llama_sbatch sbatch,
192
198
  std::vector<uint32_t> heads_attn,
193
199
  std::vector<llama_ubatch> ubatches) :
194
- sbatch(std::move(sbatch)),
195
200
  ubatches(std::move(ubatches)),
196
201
  // note: here we copy the ubatches. not sure if this is ideal
197
- state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
198
- state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {}, this->ubatches)),
199
- status(LLAMA_MEMORY_STATUS_SUCCESS) {
202
+ ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
203
+ ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
204
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
200
205
  }
201
206
 
202
- bool llama_memory_hybrid_state::next() {
207
+ bool llama_memory_hybrid_context::next() {
203
208
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
204
209
 
205
- state_attn->next();
206
- state_recr->next();
210
+ ctx_attn->next();
211
+ ctx_recr->next();
207
212
 
208
213
  if (++i_next >= ubatches.size()) {
209
214
  return false;
@@ -212,36 +217,30 @@ bool llama_memory_hybrid_state::next() {
212
217
  return true;
213
218
  }
214
219
 
215
- bool llama_memory_hybrid_state::apply() {
220
+ bool llama_memory_hybrid_context::apply() {
216
221
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
217
222
 
218
223
  bool res = true;
219
224
 
220
- res = res & state_attn->apply();
221
- res = res & state_recr->apply();
225
+ res = res & ctx_attn->apply();
226
+ res = res & ctx_recr->apply();
222
227
 
223
228
  return res;
224
229
  }
225
230
 
226
- std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
227
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
228
-
229
- return sbatch.out_ids;
230
- }
231
-
232
- llama_memory_status llama_memory_hybrid_state::get_status() const {
231
+ llama_memory_status llama_memory_hybrid_context::get_status() const {
233
232
  return status;
234
233
  }
235
234
 
236
- const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
235
+ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
237
236
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
238
237
  return ubatches[i_next];
239
238
  }
240
239
 
241
- const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
242
- return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
240
+ const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
241
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
243
242
  }
244
243
 
245
- const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
246
- return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
244
+ const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
245
+ return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
247
246
  }
@@ -49,14 +49,14 @@ public:
49
49
  // llama_memory_i
50
50
  //
51
51
 
52
- llama_memory_state_ptr init_batch(
53
- const llama_batch & batch,
52
+ llama_memory_context_ptr init_batch(
53
+ llama_batch_allocr & balloc,
54
54
  uint32_t n_ubatch,
55
- bool embd_pooled) override;
55
+ bool embd_all) override;
56
56
 
57
- llama_memory_state_ptr init_full() override;
57
+ llama_memory_context_ptr init_full() override;
58
58
 
59
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
59
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
60
60
 
61
61
  bool get_can_shift() const override;
62
62
 
@@ -90,54 +90,49 @@ private:
90
90
  const std::unique_ptr<llama_memory_recurrent> mem_recr;
91
91
  };
92
92
 
93
- class llama_memory_hybrid_state : public llama_memory_state_i {
93
+ class llama_memory_hybrid_context : public llama_memory_context_i {
94
94
  public:
95
95
  // init failure
96
- explicit llama_memory_hybrid_state(llama_memory_status status);
96
+ explicit llama_memory_hybrid_context(llama_memory_status status);
97
97
 
98
98
  // init full
99
- explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
99
+ explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
100
100
 
101
101
  // init update
102
- explicit llama_memory_hybrid_state(
102
+ explicit llama_memory_hybrid_context(
103
103
  llama_memory_hybrid * mem,
104
104
  llama_context * lctx,
105
105
  bool optimize);
106
106
 
107
107
  // init success
108
- llama_memory_hybrid_state(
108
+ llama_memory_hybrid_context(
109
109
  llama_memory_hybrid * mem,
110
- llama_sbatch sbatch,
111
110
  std::vector<uint32_t> heads_attn,
112
111
  std::vector<llama_ubatch> ubatches);
113
112
 
114
- ~llama_memory_hybrid_state() = default;
113
+ ~llama_memory_hybrid_context() = default;
115
114
 
116
115
  bool next() override;
117
116
  bool apply() override;
118
117
 
119
- std::vector<int64_t> & out_ids() override;
120
-
121
118
  llama_memory_status get_status() const override;
122
119
  const llama_ubatch & get_ubatch() const override;
123
120
 
124
121
  //
125
- // llama_memory_hybrid_state
122
+ // llama_memory_hybrid_context
126
123
  //
127
124
 
128
- const llama_kv_cache_unified_state * get_state_attn() const;
129
- const llama_memory_recurrent_state * get_state_recr() const;
125
+ const llama_kv_cache_unified_context * get_attn() const;
126
+ const llama_memory_recurrent_context * get_recr() const;
130
127
 
131
128
  private:
132
- llama_sbatch sbatch;
133
-
134
129
  // the index of the next ubatch to process
135
130
  size_t i_next = 0;
136
131
 
137
132
  std::vector<llama_ubatch> ubatches;
138
133
 
139
- const llama_memory_state_ptr state_attn;
140
- const llama_memory_state_ptr state_recr;
134
+ const llama_memory_context_ptr ctx_attn;
135
+ const llama_memory_context_ptr ctx_recr;
141
136
 
142
137
  const llama_memory_status status;
143
138
  };
@@ -362,40 +362,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
362
362
  return result;
363
363
  }
364
364
 
365
- llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
366
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
365
+ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
366
+ do {
367
+ balloc.split_reset();
367
368
 
368
- std::vector<llama_ubatch> ubatches;
369
+ std::vector<llama_ubatch> ubatches;
370
+ while (true) {
371
+ llama_ubatch ubatch;
369
372
 
370
- while (sbatch.n_tokens > 0) {
371
- llama_ubatch ubatch;
373
+ if (embd_all) {
374
+ // if all tokens are output, split by sequence
375
+ ubatch = balloc.split_seq(n_ubatch);
376
+ } else {
377
+ ubatch = balloc.split_equal(n_ubatch);
378
+ }
372
379
 
373
- if (embd_all) {
374
- // if all tokens are output, split by sequence
375
- ubatch = sbatch.split_seq(n_ubatch);
376
- } else {
377
- ubatch = sbatch.split_equal(n_ubatch);
380
+ if (ubatch.n_tokens == 0) {
381
+ break;
382
+ }
383
+
384
+ ubatches.push_back(std::move(ubatch)); // NOLINT
378
385
  }
379
386
 
380
- ubatches.push_back(ubatch);
381
- }
387
+ if (!prepare(ubatches)) {
388
+ break;
389
+ }
382
390
 
383
- if (!prepare(ubatches)) {
384
- return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
385
- }
391
+ return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
392
+ } while (false);
386
393
 
387
- return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
394
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
388
395
  }
389
396
 
390
- llama_memory_state_ptr llama_memory_recurrent::init_full() {
391
- return std::make_unique<llama_memory_recurrent_state>(this);
397
+ llama_memory_context_ptr llama_memory_recurrent::init_full() {
398
+ return std::make_unique<llama_memory_recurrent_context>(this);
392
399
  }
393
400
 
394
- llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
401
+ llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
395
402
  GGML_UNUSED(lctx);
396
403
  GGML_UNUSED(optimize);
397
404
 
398
- return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
405
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
399
406
  }
400
407
 
401
408
  bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -423,9 +430,8 @@ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches)
423
430
  }
424
431
 
425
432
  bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
426
- const uint32_t n_seqs = ubatch.n_seqs;
427
-
428
433
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
434
+ const uint32_t n_seqs = ubatch.n_seqs;
429
435
 
430
436
  // if we have enough unused cells before the current head ->
431
437
  // better to start searching from the beginning of the cache, hoping to fill it
@@ -445,9 +451,11 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
445
451
 
446
452
  // everything should fit if all seq_ids are smaller than the max
447
453
  for (uint32_t s = 0; s < n_seqs; ++s) {
448
- const uint32_t n_seq_id = ubatch.n_seq_id[s];
454
+ const uint32_t i = s*n_seq_tokens; // first token of sequence set s
455
+ const uint32_t n_seq_id = ubatch.n_seq_id[i];
456
+
449
457
  for (uint32_t j = 0; j < n_seq_id; ++j) {
450
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
458
+ const llama_seq_id seq_id = ubatch.seq_id[i][j];
451
459
 
452
460
  if (seq_id < 0 || (uint32_t) seq_id >= size) {
453
461
  // too big seq_id
@@ -506,7 +514,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
506
514
 
507
515
  // find usable cell range
508
516
  for (uint32_t s = 0; s < n_seqs; ++s) {
509
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
517
+ const uint32_t i = s*n_seq_tokens;
518
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
510
519
  auto & seq_meta = cells[seq_id];
511
520
  bool has_cell = false;
512
521
  if (seq_meta.tail >= 0) {
@@ -530,7 +539,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
530
539
  seq_meta.tail = next_empty_cell;
531
540
  // find next empty cell
532
541
  if (s + 1 < n_seqs) {
533
- for (uint32_t i = 0; i < size; ++i) {
542
+ for (uint32_t j = 0; j < size; ++j) {
534
543
  next_empty_cell += 1;
535
544
  if (next_empty_cell >= size) { next_empty_cell -= size; }
536
545
  auto & cell = cells[next_empty_cell];
@@ -544,8 +553,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
544
553
 
545
554
  // gather and re-order
546
555
  for (uint32_t s = 0; s < n_seqs; ++s) {
556
+ const uint32_t i = s*n_seq_tokens;
547
557
  const int32_t dst_id = s + min;
548
- const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
558
+ const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
549
559
  if (dst_id != src_id) {
550
560
  auto & dst_cell = cells[dst_id];
551
561
  auto & src_cell = cells[src_id];
@@ -555,8 +565,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
555
565
  std::swap(dst_cell.seq_id, src_cell.seq_id);
556
566
 
557
567
  // swap tails
558
- for (uint32_t i = 0; i < size; ++i) {
559
- int32_t & tail = cells[i].tail;
568
+ for (uint32_t j = 0; j < size; ++j) {
569
+ int32_t & tail = cells[j].tail;
560
570
  if (tail == src_id) {
561
571
  tail = dst_id;
562
572
  } else if (tail == dst_id) {
@@ -568,7 +578,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
568
578
 
569
579
  // update the pos of the used seqs
570
580
  for (uint32_t s = 0; s < n_seqs; ++s) {
571
- const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
581
+ const uint32_t i = s*n_seq_tokens;
582
+ const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
572
583
  const int32_t cell_id = s + min;
573
584
  auto & cell = cells[cell_id];
574
585
 
@@ -576,12 +587,12 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
576
587
  // What should happen when the pos backtracks or skips a value?
577
588
  // Clearing the state mid-batch would require special-casing which isn't done.
578
589
  LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
579
- __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
590
+ __func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
580
591
  }
581
592
  cell.pos = last_pos;
582
593
  cell.seq_id.clear();
583
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
584
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
594
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
595
+ const llama_seq_id seq_id = ubatch.seq_id[i][j];
585
596
  cell.seq_id.insert(seq_id);
586
597
  cells[seq_id].tail = cell_id;
587
598
  }
@@ -827,12 +838,9 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
827
838
 
828
839
  seq_rm(dest_seq_id, -1, -1);
829
840
 
830
- llama_sbatch sbatch;
831
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
841
+ llama_batch_allocr balloc(hparams.n_pos_per_embd());
832
842
 
833
- batch.n_tokens = cell_count;
834
- batch.n_seq_tokens = cell_count;
835
- batch.n_seqs = 1;
843
+ llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
836
844
 
837
845
  for (uint32_t i = 0; i < cell_count; ++i) {
838
846
  llama_pos pos;
@@ -846,12 +854,12 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
846
854
  return false;
847
855
  }
848
856
 
849
- batch.pos[i] = pos;
857
+ ubatch.pos[i] = pos;
850
858
  }
851
- batch.n_seq_id[0] = 1;
852
- batch.seq_id[0] = &dest_seq_id;
859
+ ubatch.n_seq_id[0] = 1;
860
+ ubatch.seq_id[0] = &dest_seq_id;
853
861
 
854
- if (!find_slot(batch)) {
862
+ if (!find_slot(ubatch)) {
855
863
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
856
864
  return false;
857
865
  }
@@ -859,8 +867,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
859
867
  // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
860
868
  // Assume that this is one contiguous block of cells
861
869
  GGML_ASSERT(head + cell_count <= size);
862
- GGML_ASSERT(cells[head].pos == batch.pos[0]);
863
- GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
870
+ GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
871
+ GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
864
872
  GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
865
873
  GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
866
874
  } else {
@@ -1037,23 +1045,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
1037
1045
  }
1038
1046
 
1039
1047
  //
1040
- // llama_memory_recurrent_state
1048
+ // llama_memory_recurrent_context
1041
1049
  //
1042
1050
 
1043
- llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
1051
+ llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
1044
1052
 
1045
- llama_memory_recurrent_state::llama_memory_recurrent_state(
1053
+ llama_memory_recurrent_context::llama_memory_recurrent_context(
1046
1054
  llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
1047
1055
  }
1048
1056
 
1049
- llama_memory_recurrent_state::llama_memory_recurrent_state(
1057
+ llama_memory_recurrent_context::llama_memory_recurrent_context(
1050
1058
  llama_memory_recurrent * mem,
1051
- llama_sbatch sbatch,
1052
- std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
1059
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
1053
1060
 
1054
- llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
1061
+ llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
1055
1062
 
1056
- bool llama_memory_recurrent_state::next() {
1063
+ bool llama_memory_recurrent_context::next() {
1057
1064
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1058
1065
 
1059
1066
  if (++i_next >= ubatches.size()) {
@@ -1063,7 +1070,7 @@ bool llama_memory_recurrent_state::next() {
1063
1070
  return true;
1064
1071
  }
1065
1072
 
1066
- bool llama_memory_recurrent_state::apply() {
1073
+ bool llama_memory_recurrent_context::apply() {
1067
1074
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1068
1075
 
1069
1076
  mem->find_slot(ubatches[i_next]);
@@ -1071,46 +1078,40 @@ bool llama_memory_recurrent_state::apply() {
1071
1078
  return true;
1072
1079
  }
1073
1080
 
1074
- std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
1075
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1076
-
1077
- return sbatch.out_ids;
1078
- }
1079
-
1080
- llama_memory_status llama_memory_recurrent_state::get_status() const {
1081
+ llama_memory_status llama_memory_recurrent_context::get_status() const {
1081
1082
  return status;
1082
1083
  }
1083
1084
 
1084
- const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
1085
+ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
1085
1086
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1086
1087
 
1087
1088
  return ubatches[i_next];
1088
1089
  }
1089
1090
 
1090
- uint32_t llama_memory_recurrent_state::get_n_rs() const {
1091
+ uint32_t llama_memory_recurrent_context::get_n_rs() const {
1091
1092
  return is_full ? mem->size : mem->n;
1092
1093
  }
1093
1094
 
1094
- uint32_t llama_memory_recurrent_state::get_head() const {
1095
+ uint32_t llama_memory_recurrent_context::get_head() const {
1095
1096
  return is_full ? 0 : mem->head;
1096
1097
  }
1097
1098
 
1098
- int32_t llama_memory_recurrent_state::get_rs_z() const {
1099
+ int32_t llama_memory_recurrent_context::get_rs_z() const {
1099
1100
  return is_full ? 0 : mem->rs_z;
1100
1101
  }
1101
1102
 
1102
- uint32_t llama_memory_recurrent_state::get_size() const {
1103
+ uint32_t llama_memory_recurrent_context::get_size() const {
1103
1104
  return mem->size;
1104
1105
  }
1105
1106
 
1106
- ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
1107
+ ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
1107
1108
  return mem->r_l[il];
1108
1109
  }
1109
1110
 
1110
- ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
1111
+ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
1111
1112
  return mem->s_l[il];
1112
1113
  }
1113
1114
 
1114
- int32_t llama_memory_recurrent_state::s_copy(int i) const {
1115
+ int32_t llama_memory_recurrent_context::s_copy(int i) const {
1115
1116
  return mem->cells[i + mem->head].src0;
1116
1117
  }