@novastera-oss/llamarn 0.2.7 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +56 -22
  11. package/cpp/build-info.cpp +2 -2
  12. package/cpp/llama.cpp/CMakeLists.txt +1 -1
  13. package/cpp/llama.cpp/common/arg.cpp +7 -0
  14. package/cpp/llama.cpp/common/common.cpp +3 -0
  15. package/cpp/llama.cpp/common/common.h +1 -0
  16. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  17. package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
  18. package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
  19. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  20. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  21. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
  22. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  23. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
  24. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  25. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  26. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  27. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  28. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  30. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  32. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  33. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  34. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  35. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  62. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
  64. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
  65. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  66. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
  67. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  68. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  69. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  70. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  71. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  72. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  73. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  74. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  76. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  77. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
  78. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  79. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  80. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  81. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  82. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  83. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  89. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  90. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  92. package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
  93. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  94. package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
  95. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
  96. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
  97. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  98. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  99. package/cpp/llama.cpp/include/llama.h +8 -3
  100. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  101. package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
  102. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  103. package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
  104. package/cpp/llama.cpp/src/llama-batch.h +98 -70
  105. package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
  106. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  107. package/cpp/llama.cpp/src/llama-context.h +13 -13
  108. package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
  109. package/cpp/llama.cpp/src/llama-graph.h +44 -32
  110. package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
  111. package/cpp/llama.cpp/src/llama-hparams.h +8 -0
  112. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
  113. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
  114. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
  115. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
  116. package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
  117. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
  118. package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
  119. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
  120. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  121. package/cpp/llama.cpp/src/llama-memory.h +18 -22
  122. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  123. package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
  124. package/cpp/llama.cpp/src/llama-model.h +22 -0
  125. package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
  126. package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
  127. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  128. package/cpp/rn-utils.h +3 -0
  129. package/ios/include/common.h +1 -0
  130. package/ios/include/llama.h +8 -3
  131. package/ios/libs/llama.xcframework/Info.plist +19 -19
  132. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  133. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  134. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  135. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  136. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
  137. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  138. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  139. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  140. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  141. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  142. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  143. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  144. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  145. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  146. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  147. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
  148. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  149. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  150. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
  151. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  152. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  153. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
  154. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  155. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  160. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  161. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  162. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  163. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  164. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
  165. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  168. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  173. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
  175. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
  178. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  183. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  184. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  186. package/package.json +1 -1
package/cpp/llama.cpp/src/llama-graph.h

@@ -17,12 +17,12 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-struct llama_memory_state_i;
+struct llama_memory_context_i;
 
-class llama_kv_cache_unified_state;
-class llama_kv_cache_unified_iswa_state;
-class llama_memory_recurrent_state;
-class llama_memory_hybrid_state;
+class llama_kv_cache_unified_context;
+class llama_kv_cache_unified_iswa_context;
+class llama_memory_recurrent_context;
+class llama_memory_hybrid_context;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -95,14 +95,14 @@ public:
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const int64_t n_pos_per_embd = 1;
+    const uint32_t n_pos_per_embd = 1;
 };
 
 // temperature tuning, used by llama4
@@ -136,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
+            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -144,7 +144,8 @@ public:
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-    const llama_kv_cache_unified_state * kv_state;
+
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -191,14 +192,14 @@ public:
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_memory_recurrent_state * mem_state;
+    const llama_memory_recurrent_context * mctx;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -238,10 +239,10 @@ public:
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_state * kv_state) :
+            const llama_kv_cache_unified_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        kv_state(kv_state) {
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -255,7 +256,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_state * kv_state;
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -263,10 +264,10 @@ public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_state * kv_state) :
+            const llama_kv_cache_unified_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        kv_state(kv_state) {
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -283,7 +284,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_iswa_state * kv_state;
+    const llama_kv_cache_unified_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -306,10 +307,10 @@ public:
     llm_graph_input_mem_hybrid(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_memory_hybrid_state * mem_state) :
+            const llama_memory_hybrid_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        mem_state(mem_state) {
+        mctx(mctx) {
     }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
@@ -325,7 +326,18 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_memory_hybrid_state * mem_state;
+    const llama_memory_hybrid_context * mctx;
+};
+
+// TODO: remove this when ggml_scale_add is implemented
+class llm_graph_input_one : public llm_graph_input_i {
+public:
+    llm_graph_input_one() {}
+    virtual ~llm_graph_input_one() = default;
+
+    void set_input(const llama_ubatch *) override;
+
+    ggml_tensor * one = nullptr; // F32
 };
 
 //
@@ -401,10 +413,10 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec   * cvec;
-    const llama_adapter_loras  * loras;
-    const llama_memory_state_i * mstate;
-    const llama_cross          * cross;
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
 
     uint32_t n_outputs;
 
@@ -453,18 +465,17 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec   * cvec;
-    const llama_adapter_loras  * loras;
-    const llama_memory_state_i * mstate;
-    const llama_cross          * cross;
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
 
     const llm_graph_cb & cb_func;
 
     std::unique_ptr<llm_graph_result> res;
 
     llm_graph_context(const llm_graph_params & params);
-
-    int64_t n_pos_per_embd() const;
+    virtual ~llm_graph_context() = default;
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
@@ -590,14 +601,15 @@ struct llm_graph_context {
 
     llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
+    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
            llm_graph_input_attn_kv_unified_iswa * inp,
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-           ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-           ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+           ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+           ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
            ggml_tensor * kq_b,
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
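
The "optional" note above is the behavioral change in this hunk: build_attn can now attend over K/V that an earlier layer already stored in the unified cache, instead of always writing its own. A rough sketch of the two call shapes under stated assumptions: inp, gf, wo, wo_b, q_cur and kq_scale are illustrative caller state, and the trailing layer-index argument il is assumed from upstream llama.cpp since the hunk cuts off after kq_scale.

    // typical layer: k_cur/v_cur are stored into the KV cache, then attended over
    ggml_tensor * cur = build_attn(inp, gf, wo, wo_b,
            q_cur, k_cur, v_cur, /*kq_b=*/nullptr, /*v_mla=*/nullptr, kq_scale, il);

    // KV-reuse layer: nullptr k_cur/v_cur means nothing new is written to the cache;
    // attention runs over whatever K/V the cache already holds
    ggml_tensor * cur_reuse = build_attn(inp, gf, wo, wo_b,
            q_cur, /*k_cur=*/nullptr, /*v_cur=*/nullptr, nullptr, nullptr, kq_scale, il);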
package/cpp/llama.cpp/src/llama-hparams.cpp

@@ -90,6 +90,10 @@ bool llama_hparams::is_recurrent(uint32_t il) const {
     return recurrent_layer_arr[il];
 }
 
+uint32_t llama_hparams::n_pos_per_embd() const {
+    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
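
The new helper centralizes position-dimension logic that previously lived in llm_graph_context::n_pos_per_embd() (removed in the llama-graph.h hunk above): M-RoPE models carry four position components per token, all others carry one. A minimal sketch of the sizing consequence, assuming, as in upstream llama.cpp, that the I32 position input is allocated per component (ctx and n_tokens are illustrative):

    // hypothetical consumer of the new helper
    const uint32_t n_pos_per_embd = hparams.n_pos_per_embd(); // 4 for LLAMA_ROPE_TYPE_MROPE, else 1
    ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) n_tokens * n_pos_per_embd);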
package/cpp/llama.cpp/src/llama-hparams.h

@@ -143,6 +143,12 @@ struct llama_hparams {
     uint32_t n_attn_temp_floor_scale = 8192;
     float f_attn_temp_scale = 0.1;
 
+    // gemma3n altup
+    uint32_t n_altup      = 4; // altup_num_inputs
+    uint32_t i_altup_act  = 0; // altup_active_idx
+    uint32_t laurel_rank  = 64;
+    uint32_t n_embd_altup = 256;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@@ -192,6 +198,8 @@ struct llama_hparams {
     // whether or not the given layer is recurrent (for hybrid models)
     bool is_recurrent(uint32_t il) const;
 
+    uint32_t n_pos_per_embd() const;
+
     bool is_swa(uint32_t il) const;
 };
 
package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp

@@ -95,19 +95,22 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_simple(n_ubatch);
 
-        while (sbatch.n_tokens > 0) {
-            auto ubatch = sbatch.split_simple(n_ubatch);
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
 
-            ubatches.push_back(ubatch);
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -122,20 +125,23 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(
-            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+            this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_equal(n_ubatch);
 
-        while (sbatch.n_tokens > 0) {
-            auto ubatch = sbatch.split_equal(n_ubatch);
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
 
-            ubatches.push_back(ubatch);
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -150,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(
-            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+            this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     // but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -191,48 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 }
 
 //
-// llama_kv_cache_unified_iswa_state
+// llama_kv_cache_unified_iswa_context
 //
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv) :
-    state_base(kv->get_base()->init_full()),
-    state_swa (kv->get_swa ()->init_full()),
-    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+    ctx_base(kv->get_base()->init_full()),
+    ctx_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
         bool optimize) :
-    state_base(kv->get_base()->init_update(lctx, optimize)),
-    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+    ctx_base(kv->get_base()->init_update(lctx, optimize)),
+    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-        llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches) :
-    sbatch(std::move(sbatch)),
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
-    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)),
-    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
+llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
 
-bool llama_kv_cache_unified_iswa_state::next() {
+bool llama_kv_cache_unified_iswa_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    state_base->next();
-    state_swa ->next();
+    ctx_base->next();
+    ctx_swa ->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -241,41 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_iswa_state::apply() {
+bool llama_kv_cache_unified_iswa_context::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     bool res = true;
 
-    res = res & state_base->apply();
-    res = res & state_swa ->apply();
+    res = res & ctx_base->apply();
+    res = res & ctx_swa ->apply();
 
     return res;
 }
 
-std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
-llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
+llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
 }
 
-const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
 }
package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h

@@ -31,14 +31,14 @@ public:
     // llama_memory_i
     //
 
-    llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
+    llama_memory_context_ptr init_batch(
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;
 
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -72,62 +72,57 @@ private:
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
 
-class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
+class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
     // used for errors
-    llama_kv_cache_unified_iswa_state(llama_memory_status status);
+    llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
-    // used to create a full-cache state
-    llama_kv_cache_unified_iswa_state(
+    // used to create a full-cache context
+    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv);
 
-    // used to create an update state
-    llama_kv_cache_unified_iswa_state(
+    // used to create an update context
+    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv,
            llama_context * lctx,
            bool optimize);
 
-    // used to create a state from a batch
-    llama_kv_cache_unified_iswa_state(
+    // used to create a batch processing context from a batch
+    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv,
-           llama_sbatch sbatch,
           std::vector<uint32_t> heads_base,
           std::vector<uint32_t> heads_swa,
           std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_iswa_state();
+    virtual ~llama_kv_cache_unified_iswa_context();
 
    //
-    // llama_memory_state_i
+    // llama_memory_context_i
    //
 
    bool next() override;
    bool apply() override;
 
-    std::vector<int64_t> & out_ids() override;
-
    llama_memory_status get_status() const override;
    const llama_ubatch & get_ubatch() const override;
 
    //
-    // llama_kv_cache_unified_iswa_state specific API
+    // llama_kv_cache_unified_iswa_context specific API
    //
 
-    const llama_kv_cache_unified_state * get_base() const;
-    const llama_kv_cache_unified_state * get_swa() const;
+    const llama_kv_cache_unified_context * get_base() const;
+    const llama_kv_cache_unified_context * get_swa() const;
 
 private:
    //llama_kv_cache_unified_iswa * kv;
 
-    llama_sbatch sbatch;
-
    // the index of the next ubatch to process
    size_t i_next = 0;
 
    std::vector<llama_ubatch> ubatches;
 
-    const llama_memory_state_ptr state_base;
-    const llama_memory_state_ptr state_swa;
+    const llama_memory_context_ptr ctx_base;
+    const llama_memory_context_ptr ctx_swa;
 
    const llama_memory_status status;
 };
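
Taken together, the kv-cache hunks are one mechanical rename (llama_memory_state_* becomes llama_memory_context_*, state_base/state_swa become ctx_base/ctx_swa) plus one structural change: the llama_sbatch member and out_ids() are gone, and batch splitting now happens through llama_batch_allocr. A rough sketch of how a caller might drive the renamed interface, using only the methods visible in this diff (mem, balloc and n_ubatch are assumed caller state; error handling abbreviated):

    llama_memory_context_ptr mctx = mem->init_batch(balloc, n_ubatch, /*embd_all=*/false);
    if (mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false; // e.g. LLAMA_MEMORY_STATUS_FAILED_PREPARE from the fallback path above
    }
    do {
        if (!mctx->apply()) { // position the underlying cache(s) for the current ubatch
            return false;
        }
        const llama_ubatch & ubatch = mctx->get_ubatch();
        // ... build and compute the graph for this ubatch ...
    } while (mctx->next()); // returns false once every ubatch has been consumed
    return true;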