@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -17,11 +17,12 @@ struct ggml_tensor;
17
17
  struct llama_ubatch;
18
18
  struct llama_cparams;
19
19
 
20
- class llama_memory_state_i;
20
+ struct llama_memory_state_i;
21
21
 
22
22
  class llama_kv_cache_unified_state;
23
23
  class llama_kv_cache_unified_iswa_state;
24
- class llama_kv_cache_recurrent_state;
24
+ class llama_memory_recurrent_state;
25
+ class llama_memory_hybrid_state;
25
26
 
26
27
  // certain models (typically multi-modal) can produce different types of graphs
27
28
  enum llm_graph_type {
@@ -36,6 +37,7 @@ enum llm_ffn_op_type {
36
37
  LLM_FFN_RELU,
37
38
  LLM_FFN_RELU_SQR,
38
39
  LLM_FFN_SWIGLU,
40
+ LLM_FFN_GEGLU,
39
41
  };
40
42
 
41
43
  enum llm_ffn_gate_type {
@@ -187,28 +189,16 @@ public:
187
189
  const llama_cparams & cparams;
188
190
  };
189
191
 
190
- class llm_graph_input_s_copy : public llm_graph_input_i {
192
+ class llm_graph_input_rs : public llm_graph_input_i {
191
193
  public:
192
- llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
193
- virtual ~llm_graph_input_s_copy() = default;
194
+ llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
195
+ virtual ~llm_graph_input_rs() = default;
194
196
 
195
197
  void set_input(const llama_ubatch * ubatch) override;
196
198
 
197
199
  ggml_tensor * s_copy; // I32 [kv_size]
198
200
 
199
- const llama_kv_cache_recurrent_state * kv_state;
200
- };
201
-
202
- class llm_graph_input_s_mask : public llm_graph_input_i {
203
- public:
204
- llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
205
- virtual ~llm_graph_input_s_mask() = default;
206
-
207
- void set_input(const llama_ubatch * ubatch) override;
208
-
209
- ggml_tensor * s_mask; // F32 [1, n_kv]
210
-
211
- const llama_kv_cache_recurrent_state * kv_state;
201
+ const llama_memory_recurrent_state * mem_state;
212
202
  };
213
203
 
214
204
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -311,6 +301,33 @@ public:
311
301
  const llama_cross * cross = nullptr;
312
302
  };
313
303
 
304
+ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
305
+ public:
306
+ llm_graph_input_mem_hybrid(
307
+ const llama_hparams & hparams,
308
+ const llama_cparams & cparams,
309
+ const llama_memory_hybrid_state * mem_state) :
310
+ hparams(hparams),
311
+ cparams(cparams),
312
+ mem_state(mem_state) {
313
+ }
314
+ virtual ~llm_graph_input_mem_hybrid() = default;
315
+
316
+ void set_input(const llama_ubatch * ubatch) override;
317
+
318
+ ggml_tensor * s_copy; // I32 [kv_size]
319
+
320
+ ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
321
+
322
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
323
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
324
+
325
+ const llama_hparams & hparams;
326
+ const llama_cparams & cparams;
327
+
328
+ const llama_memory_hybrid_state * mem_state;
329
+ };
330
+
314
331
  //
315
332
  // llm_graph_result
316
333
  //
@@ -389,7 +406,7 @@ struct llm_graph_params {
389
406
  const llama_memory_state_i * mstate;
390
407
  const llama_cross * cross;
391
408
 
392
- int32_t n_outputs;
409
+ uint32_t n_outputs;
393
410
 
394
411
  const llm_graph_cb & cb;
395
412
  };
@@ -423,8 +440,8 @@ struct llm_graph_context {
423
440
  const float norm_eps;
424
441
  const float norm_rms_eps;
425
442
 
426
- const int32_t n_tokens;
427
- const int32_t n_outputs;
443
+ const int64_t n_tokens;
444
+ const int64_t n_outputs;
428
445
  const int32_t n_ctx_orig; // yarn
429
446
 
430
447
  const enum llama_pooling_type pooling_type;
@@ -519,14 +536,14 @@ struct llm_graph_context {
519
536
  ggml_tensor * build_inp_out_ids() const;
520
537
  ggml_tensor * build_inp_mean() const;
521
538
  ggml_tensor * build_inp_cls() const;
522
- ggml_tensor * build_inp_s_copy() const;
523
- ggml_tensor * build_inp_s_mask() const;
524
539
 
525
540
  ggml_tensor * build_inp_cross_embd() const;
526
541
  ggml_tensor * build_inp_pos_bucket_enc() const;
527
542
  ggml_tensor * build_inp_pos_bucket_dec() const;
528
543
  ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
529
544
 
545
+ llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
546
+
530
547
  //
531
548
  // attention
532
549
  //
@@ -601,23 +618,62 @@ struct llm_graph_context {
601
618
  float kq_scale,
602
619
  int il) const;
603
620
 
621
+ ggml_tensor * build_attn(
622
+ llm_graph_input_mem_hybrid * inp,
623
+ ggml_cgraph * gf,
624
+ ggml_tensor * wo,
625
+ ggml_tensor * wo_b,
626
+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
627
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
628
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
629
+ ggml_tensor * kq_b,
630
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
631
+ float kq_scale,
632
+ int il) const;
604
633
  //
605
634
  // recurrent
606
635
  //
607
636
 
608
- ggml_tensor * build_copy_mask_state(
609
- ggml_cgraph * gf,
610
- ggml_tensor * s,
611
- ggml_tensor * state_copy,
612
- ggml_tensor * state_mask,
613
- int32_t n_state,
614
- int32_t n_seqs) const;
637
+ // TODO: avoid notion of "kv"
638
+ // TODO: move this implementation to llama_memory_recurrent.
639
+ // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
640
+ // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
641
+ // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
642
+ // `llama_memory_recurrent`
643
+ ggml_tensor * build_rs(
644
+ ggml_cgraph * gf,
645
+ ggml_tensor * s,
646
+ ggml_tensor * state_copy,
647
+ int32_t state_size,
648
+ int32_t n_seqs,
649
+ uint32_t n_kv,
650
+ uint32_t kv_head,
651
+ uint32_t kv_size,
652
+ int32_t rs_zero,
653
+ bool avoid_copies = false) const;
654
+
655
+ llm_graph_input_rs * build_rs_inp() const;
656
+
657
+ ggml_tensor * build_rs(
658
+ llm_graph_input_rs * inp,
659
+ ggml_cgraph * gf,
660
+ ggml_tensor * s,
661
+ int32_t state_size,
662
+ int32_t n_seqs,
663
+ bool avoid_copies = false) const;
664
+
665
+ ggml_tensor * build_rs(
666
+ llm_graph_input_mem_hybrid * inp,
667
+ ggml_cgraph * gf,
668
+ ggml_tensor * s,
669
+ int32_t state_size,
670
+ int32_t n_seqs,
671
+ bool avoid_copies = false) const;
615
672
 
616
673
  ggml_tensor * build_rwkv_token_shift_load(
617
- ggml_cgraph * gf,
618
- ggml_tensor * state_copy,
619
- ggml_tensor * state_mask,
620
- const llama_ubatch & ubatch,
674
+ llm_graph_input_rs * inp,
675
+ ggml_cgraph * gf,
676
+ const llama_ubatch & ubatch,
621
677
  int il) const;
622
678
 
623
679
  ggml_tensor * build_rwkv_token_shift_store(
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
65
65
  return n_embd_head_v * n_head_kv;
66
66
  }
67
67
 
68
- uint32_t llama_hparams::n_embd_k_s() const {
68
+ uint32_t llama_hparams::n_embd_r() const {
69
69
  if (wkv_head_size != 0) {
70
70
  // for RWKV models
71
71
  return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
76
76
  return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
77
77
  }
78
78
 
79
- uint32_t llama_hparams::n_embd_v_s() const {
79
+ uint32_t llama_hparams::n_embd_s() const {
80
80
  if (wkv_head_size != 0) {
81
81
  // corresponds to RWKV's wkv_states size
82
82
  return n_embd * wkv_head_size;
@@ -86,6 +86,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
86
86
  return ssm_d_state * ssm_d_inner;
87
87
  }
88
88
 
89
+ bool llama_hparams::is_recurrent(uint32_t il) const {
90
+ return recurrent_layer_arr[il];
91
+ }
92
+
89
93
  bool llama_hparams::is_swa(uint32_t il) const {
90
94
  if (il < n_layer) {
91
95
  return swa_layers[il];
@@ -115,6 +115,9 @@ struct llama_hparams {
115
115
  uint32_t ssm_d_state = 0;
116
116
  uint32_t ssm_dt_rank = 0;
117
117
 
118
+ // for hybrid state space models
119
+ std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
120
+
118
121
  bool ssm_dt_b_c_rms = false;
119
122
 
120
123
  float f_clamp_kqv = 0.0f;
@@ -181,10 +184,13 @@ struct llama_hparams {
181
184
 
182
185
  // dimension of the rolling state embeddings
183
186
  // corresponds to Mamba's conv_states size or RWKV's token_shift states size
184
- uint32_t n_embd_k_s() const;
187
+ uint32_t n_embd_r() const;
185
188
 
186
189
  // dimension of the recurrent state embeddings
187
- uint32_t n_embd_v_s() const;
190
+ uint32_t n_embd_s() const;
191
+
192
+ // whether or not the given layer is recurrent (for hybrid models)
193
+ bool is_recurrent(uint32_t il) const;
188
194
 
189
195
  bool is_swa(uint32_t il) const;
190
196
  };
@@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
52
52
  hparams.n_swa, hparams.swa_type);
53
53
  }
54
54
 
55
- void llama_kv_cache_unified_iswa::clear() {
56
- kv_base->clear();
57
- kv_swa ->clear();
55
+ void llama_kv_cache_unified_iswa::clear(bool data) {
56
+ kv_base->clear(data);
57
+ kv_swa ->clear(data);
58
58
  }
59
59
 
60
60
  bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -95,54 +95,77 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
95
95
  return kv_swa->seq_pos_max(seq_id);
96
96
  }
97
97
 
98
- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
99
- GGML_UNUSED(embd_pooled);
98
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
99
+ GGML_UNUSED(embd_all);
100
100
 
101
- // TODO: if we fail with split_simple, we should attempt different splitting strategies
102
- // but to do that properly, we first have to refactor the batches to be more flexible
101
+ // first try simple split
102
+ do {
103
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
103
104
 
104
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
105
+ std::vector<llama_ubatch> ubatches;
105
106
 
106
- std::vector<llama_ubatch> ubatches;
107
+ while (sbatch.n_tokens > 0) {
108
+ auto ubatch = sbatch.split_simple(n_ubatch);
107
109
 
108
- while (sbatch.n_tokens > 0) {
109
- auto ubatch = sbatch.split_simple(n_ubatch);
110
+ ubatches.push_back(ubatch);
111
+ }
110
112
 
111
- ubatches.push_back(ubatch);
112
- }
113
+ auto heads_base = kv_base->prepare(ubatches);
114
+ if (heads_base.empty()) {
115
+ break;
116
+ }
113
117
 
114
- auto heads_base = kv_base->prepare(ubatches);
115
- if (heads_base.empty()) {
116
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
117
- }
118
+ auto heads_swa = kv_swa->prepare(ubatches);
119
+ if (heads_swa.empty()) {
120
+ break;
121
+ }
118
122
 
119
- auto heads_swa = kv_swa->prepare(ubatches);
120
- if (heads_swa.empty()) {
121
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
122
- }
123
+ assert(heads_base.size() == heads_swa.size());
123
124
 
124
- assert(heads_base.size() == heads_swa.size());
125
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
126
+ this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
127
+ } while (false);
125
128
 
126
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
127
- this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
128
- }
129
+ // if it fails, try equal split
130
+ do {
131
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
129
132
 
130
- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
131
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
132
- }
133
+ std::vector<llama_ubatch> ubatches;
133
134
 
134
- bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
135
- bool res = false;
135
+ while (sbatch.n_tokens > 0) {
136
+ auto ubatch = sbatch.split_equal(n_ubatch);
136
137
 
137
- res = res | kv_base->update(lctx);
138
- res = res | kv_swa ->update(lctx);
138
+ ubatches.push_back(ubatch);
139
+ }
139
140
 
140
- return res;
141
+ auto heads_base = kv_base->prepare(ubatches);
142
+ if (heads_base.empty()) {
143
+ break;
144
+ }
145
+
146
+ auto heads_swa = kv_swa->prepare(ubatches);
147
+ if (heads_swa.empty()) {
148
+ break;
149
+ }
150
+
151
+ assert(heads_base.size() == heads_swa.size());
152
+
153
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
154
+ this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
155
+ } while (false);
156
+
157
+ // TODO: if we fail again, we should attempt different splitting strategies
158
+ // but to do that properly, we first have to refactor the batches to be more flexible
159
+
160
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
161
+ }
162
+
163
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
164
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
141
165
  }
142
166
 
143
- void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
144
- kv_base->defrag_sched(thold);
145
- kv_swa ->defrag_sched(thold);
167
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
168
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
146
169
  }
147
170
 
148
171
  bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,26 +197,34 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
174
197
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
175
198
 
176
199
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
177
- llama_memory_status status,
178
- llama_kv_cache_unified_iswa * kv) : status(status) {
179
- state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
180
- state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
200
+ llama_kv_cache_unified_iswa * kv) :
201
+ state_base(kv->get_base()->init_full()),
202
+ state_swa (kv->get_swa ()->init_full()),
203
+ status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
204
+ }
205
+
206
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
207
+ llama_kv_cache_unified_iswa * kv,
208
+ llama_context * lctx,
209
+ bool optimize) :
210
+ state_base(kv->get_base()->init_update(lctx, optimize)),
211
+ state_swa (kv->get_swa ()->init_update(lctx, optimize)),
212
+ status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
181
213
  }
182
214
 
183
215
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
184
- llama_memory_status status,
185
216
  llama_kv_cache_unified_iswa * kv,
186
217
  llama_sbatch sbatch,
187
218
  std::vector<uint32_t> heads_base,
188
219
  std::vector<uint32_t> heads_swa,
189
- std::vector<llama_ubatch> ubatches)
190
- : status(status),
220
+ std::vector<llama_ubatch> ubatches) :
191
221
  sbatch(std::move(sbatch)),
192
- ubatches(std::move(ubatches)) {
193
- // note: here we copy the ubatches. not sure if this is ideal
194
- state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
195
- state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
196
- }
222
+ ubatches(std::move(ubatches)),
223
+ // note: here we copy the ubatches. not sure if this is ideal
224
+ state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
225
+ state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)),
226
+ status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
227
+ }
197
228
 
198
229
  llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
199
230
 
@@ -233,17 +264,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
233
264
 
234
265
  const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
235
266
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
267
+
236
268
  return ubatches[i_next];
237
269
  }
238
270
 
239
271
  const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
240
272
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
241
273
 
242
- return state_base.get();
274
+ return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
243
275
  }
244
276
 
245
277
  const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
246
278
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
247
279
 
248
- return state_swa.get();
280
+ return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
249
281
  }
@@ -11,7 +11,7 @@
11
11
  // utilizes two instances of llama_kv_cache_unified
12
12
  // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
13
13
 
14
- class llama_kv_cache_unified_iswa : public llama_kv_cache {
14
+ class llama_kv_cache_unified_iswa : public llama_memory_i {
15
15
  public:
16
16
  llama_kv_cache_unified_iswa(
17
17
  const llama_model & model,
@@ -31,7 +31,18 @@ public:
31
31
  // llama_memory_i
32
32
  //
33
33
 
34
- void clear() override;
34
+ llama_memory_state_ptr init_batch(
35
+ const llama_batch & batch,
36
+ uint32_t n_ubatch,
37
+ bool embd_all) override;
38
+
39
+ llama_memory_state_ptr init_full() override;
40
+
41
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
42
+
43
+ bool get_can_shift() const override;
44
+
45
+ void clear(bool data) override;
35
46
 
36
47
  bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
37
48
  void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -42,24 +53,6 @@ public:
42
53
  llama_pos seq_pos_min(llama_seq_id seq_id) const override;
43
54
  llama_pos seq_pos_max(llama_seq_id seq_id) const override;
44
55
 
45
- //
46
- // llama_kv_cache
47
- //
48
-
49
- llama_memory_state_ptr init_batch(
50
- const llama_batch & batch,
51
- uint32_t n_ubatch,
52
- bool embd_pooled,
53
- bool logits_all) override;
54
-
55
- llama_memory_state_ptr init_full() override;
56
-
57
- bool update(llama_context & lctx) override;
58
-
59
- void defrag_sched(float thold) override;
60
-
61
- bool get_can_shift() const override;
62
-
63
56
  // state write/load
64
57
 
65
58
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -86,12 +79,16 @@ public:
86
79
 
87
80
  // used to create a full-cache state
88
81
  llama_kv_cache_unified_iswa_state(
89
- llama_memory_status status,
90
82
  llama_kv_cache_unified_iswa * kv);
91
83
 
84
+ // used to create an update state
85
+ llama_kv_cache_unified_iswa_state(
86
+ llama_kv_cache_unified_iswa * kv,
87
+ llama_context * lctx,
88
+ bool optimize);
89
+
92
90
  // used to create a state from a batch
93
91
  llama_kv_cache_unified_iswa_state(
94
- llama_memory_status status,
95
92
  llama_kv_cache_unified_iswa * kv,
96
93
  llama_sbatch sbatch,
97
94
  std::vector<uint32_t> heads_base,
@@ -120,8 +117,6 @@ public:
120
117
  const llama_kv_cache_unified_state * get_swa() const;
121
118
 
122
119
  private:
123
- const llama_memory_status status;
124
-
125
120
  //llama_kv_cache_unified_iswa * kv;
126
121
 
127
122
  llama_sbatch sbatch;
@@ -131,6 +126,8 @@ private:
131
126
 
132
127
  std::vector<llama_ubatch> ubatches;
133
128
 
134
- std::unique_ptr<llama_kv_cache_unified_state> state_base;
135
- std::unique_ptr<llama_kv_cache_unified_state> state_swa;
129
+ const llama_memory_state_ptr state_base;
130
+ const llama_memory_state_ptr state_swa;
131
+
132
+ const llama_memory_status status;
136
133
  };