@novastera-oss/llamarn 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakePresets.json +11 -0
  22. package/cpp/llama.cpp/CODEOWNERS +1 -0
  23. package/cpp/llama.cpp/README.md +4 -3
  24. package/cpp/llama.cpp/common/arg.cpp +45 -1
  25. package/cpp/llama.cpp/common/common.cpp +22 -6
  26. package/cpp/llama.cpp/common/common.h +18 -4
  27. package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
  28. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
  30. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  31. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  32. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  34. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
  35. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
  77. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
  78. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
  79. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
  109. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  112. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  115. package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
  116. package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
  117. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  118. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
  120. package/cpp/llama.cpp/include/llama.h +15 -7
  121. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  122. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  123. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  124. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  125. package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
  126. package/cpp/llama.cpp/src/llama-arch.h +5 -0
  127. package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
  128. package/cpp/llama.cpp/src/llama-batch.h +24 -18
  129. package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
  130. package/cpp/llama.cpp/src/llama-chat.h +2 -0
  131. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  132. package/cpp/llama.cpp/src/llama-context.h +26 -16
  133. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  134. package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
  135. package/cpp/llama.cpp/src/llama-graph.h +147 -72
  136. package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
  137. package/cpp/llama.cpp/src/llama-hparams.h +10 -2
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  139. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  140. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  141. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  142. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
  144. package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
  145. package/cpp/llama.cpp/src/llama-model.h +3 -4
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
  148. package/cpp/llama.cpp/src/llama-vocab.h +2 -0
  149. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  150. package/cpp/llama.cpp/src/unicode.h +2 -0
  151. package/ios/include/common.h +18 -4
  152. package/ios/include/llama.h +15 -7
  153. package/ios/libs/llama.xcframework/Info.plist +15 -15
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  155. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  158. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  165. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  172. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  174. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
  175. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  176. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  177. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  178. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  179. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  180. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
  183. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
  184. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  186. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
  187. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
  188. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  189. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  190. package/package.json +4 -4
@@ -1,6 +1,7 @@
1
1
  #pragma once
2
2
 
3
3
  #include "llama-arch.h"
4
+ #include "llama-batch.h"
4
5
  #include "llama-hparams.h"
5
6
  #include "llama-adapter.h"
6
7
 
@@ -14,7 +15,6 @@ struct ggml_cgraph;
14
15
  struct ggml_context;
15
16
  struct ggml_tensor;
16
17
 
17
- struct llama_ubatch;
18
18
  struct llama_cparams;
19
19
 
20
20
  struct llama_memory_context_i;
@@ -69,6 +69,8 @@ struct llama_cross {
69
69
  std::vector<std::set<llama_seq_id>> seq_ids_enc;
70
70
  };
71
71
 
72
+ struct llm_graph_params;
73
+
72
74
  //
73
75
  // llm_graph_input
74
76
  //
@@ -78,11 +80,19 @@ public:
78
80
  virtual ~llm_graph_input_i() = default;
79
81
 
80
82
  virtual void set_input(const llama_ubatch * ubatch) = 0;
83
+
84
+ // return true if the resulting input tensors using the provided graph parameters would be
85
+ // the same as the previous input tensors that we have currently stored in the object
86
+ virtual bool can_reuse(const llm_graph_params & params) {
87
+ // returning false here by default will prevent from reusing the graph if the check
88
+ // for the input type has not been implemented yet
89
+ GGML_UNUSED(params);
90
+ return false;
91
+ }
81
92
  };
82
93
 
83
94
  using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
84
95
 
85
-
86
96
  class llm_graph_input_embd : public llm_graph_input_i {
87
97
  public:
88
98
  llm_graph_input_embd() = default;
@@ -90,6 +100,8 @@ public:
90
100
 
91
101
  void set_input(const llama_ubatch * ubatch) override;
92
102
 
103
+ bool can_reuse(const llm_graph_params & params) override;
104
+
93
105
  ggml_tensor * tokens = nullptr; // I32 [n_batch]
94
106
  ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
95
107
  };
@@ -101,6 +113,8 @@ public:
101
113
 
102
114
  void set_input(const llama_ubatch * ubatch) override;
103
115
 
116
+ bool can_reuse(const llm_graph_params & params) override;
117
+
104
118
  ggml_tensor * pos = nullptr; // I32 [n_batch]
105
119
 
106
120
  const uint32_t n_pos_per_embd = 1;
@@ -154,17 +168,19 @@ public:
154
168
  llm_graph_input_out_ids(
155
169
  const llama_hparams & hparams,
156
170
  const llama_cparams & cparams,
157
- int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
171
+ uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
158
172
  virtual ~llm_graph_input_out_ids() = default;
159
173
 
160
174
  void set_input(const llama_ubatch * ubatch) override;
161
175
 
176
+ bool can_reuse(const llm_graph_params & params) override;
177
+
162
178
  ggml_tensor * out_ids; // I32 [n_outputs]
163
179
 
164
180
  const llama_hparams & hparams;
165
181
  const llama_cparams & cparams;
166
182
 
167
- const int32_t n_outputs;
183
+ const uint32_t n_outputs;
168
184
  };
169
185
 
170
186
  class llm_graph_input_mean : public llm_graph_input_i {
@@ -249,16 +265,18 @@ public:
249
265
 
250
266
  void set_input(const llama_ubatch * ubatch) override;
251
267
 
268
+ bool can_reuse(const llm_graph_params & params) override;
269
+
252
270
  ggml_tensor * get_k_idxs() const { return self_k_idxs; }
253
271
  ggml_tensor * get_v_idxs() const { return self_v_idxs; }
254
272
 
255
273
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
256
274
 
257
275
  ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
258
- ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
276
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
259
277
 
260
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
261
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
278
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
279
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
262
280
 
263
281
  const llama_hparams & hparams;
264
282
  const llama_cparams & cparams;
@@ -280,6 +298,8 @@ public:
280
298
 
281
299
  void set_input(const llama_ubatch * ubatch) override;
282
300
 
301
+ bool can_reuse(const llm_graph_params & params) override;
302
+
283
303
  ggml_tensor * get_k_idxs() const { return self_k_idxs; }
284
304
  ggml_tensor * get_v_idxs() const { return self_v_idxs; }
285
305
  ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
@@ -289,14 +309,14 @@ public:
289
309
  ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
290
310
 
291
311
  ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
292
- ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
312
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
293
313
  ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
294
- ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
314
+ ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
295
315
 
296
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
297
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
298
- ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
299
- ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
316
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
317
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
318
+ ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
319
+ ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
300
320
 
301
321
  const llama_hparams & hparams;
302
322
  const llama_cparams & cparams;
@@ -351,40 +371,108 @@ public:
351
371
  // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
352
372
  // these are used by the llama_context to extract the relevant data, based on the compute parameters
353
373
 
354
- class llm_graph_result_i {
355
- public:
356
- virtual ~llm_graph_result_i() = default;
374
+ // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
375
+ using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
357
376
 
358
- virtual ggml_tensor * get_tokens() = 0;
359
- virtual ggml_tensor * get_logits() = 0;
360
- virtual ggml_tensor * get_embd() = 0;
361
- virtual ggml_tensor * get_embd_pooled() = 0;
377
+ class llm_graph_result;
362
378
 
363
- virtual void set_inputs(const llama_ubatch * ubatch) = 0;
364
- };
379
+ struct llm_graph_params {
380
+ llm_arch arch = LLM_ARCH_UNKNOWN;
365
381
 
366
- using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
382
+ llama_hparams hparams;
383
+ llama_cparams cparams;
367
384
 
385
+ llama_ubatch ubatch; // note: intentionally make a copy
368
386
 
369
- class llm_graph_result : public llm_graph_result_i {
370
- public:
371
- virtual ~llm_graph_result() = default;
387
+ llm_graph_type gtype;
372
388
 
373
- ggml_tensor * get_tokens() override { return t_tokens; }
374
- ggml_tensor * get_logits() override { return t_logits; }
375
- ggml_tensor * get_embd() override { return t_embd; }
376
- ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
389
+ ggml_backend_sched_t sched;
390
+ ggml_backend_t backend_cpu;
377
391
 
378
- void set_inputs(const llama_ubatch * ubatch) override {
379
- for (auto & input : inputs) {
380
- input->set_input(ubatch);
392
+ const llama_adapter_cvec * cvec;
393
+ const llama_adapter_loras * loras;
394
+ const llama_memory_context_i * mctx;
395
+ const llama_cross * cross;
396
+
397
+ uint32_t n_outputs;
398
+
399
+ llm_graph_cb cb;
400
+
401
+ llm_graph_result * res;
402
+
403
+ // return true if the "other" params would result in a graph with the same topology as with the current params
404
+ // having the same topology allows us to reuse the graph in some cases
405
+ bool allow_reuse(const llm_graph_params & other) const {
406
+ // first check the ubatch
407
+ bool can_reuse_ubatch =
408
+ ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
409
+ ubatch.n_tokens == other.ubatch.n_tokens &&
410
+ ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
411
+ ubatch.n_seqs == other.ubatch.n_seqs &&
412
+ ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
413
+ (
414
+ (!ubatch.token && !other.ubatch.token) ||
415
+ (!ubatch.embd && !other.ubatch.embd)
416
+ );
417
+
418
+ if (can_reuse_ubatch && !ubatch.equal_seqs()) {
419
+ if (!ubatch.data) {
420
+ // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
421
+ // therefore we cannot perform the sequence id check. normally should never happen
422
+ can_reuse_ubatch = false;
423
+ } else {
424
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
425
+ can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
426
+ }
427
+ }
381
428
  }
382
- }
383
429
 
384
- llm_graph_input_i * add_input(llm_graph_input_ptr input) {
385
- inputs.emplace_back(std::move(input));
386
- return inputs.back().get();
430
+ if (!can_reuse_ubatch) {
431
+ return false;
432
+ }
433
+
434
+ return
435
+ cparams.embeddings == other.cparams.embeddings &&
436
+ cparams.causal_attn == other.cparams.causal_attn &&
437
+ arch == other.arch &&
438
+ gtype == other.gtype &&
439
+ cvec == other.cvec &&
440
+ loras == other.loras &&
441
+ cross == other.cross &&
442
+ n_outputs == other.n_outputs;
387
443
  }
444
+ };
445
+
446
+ class llm_graph_result {
447
+ public:
448
+ llm_graph_result(int64_t max_nodes);
449
+
450
+ virtual ~llm_graph_result() = default;
451
+
452
+ ggml_tensor * get_tokens() const { return t_tokens; }
453
+ ggml_tensor * get_logits() const { return t_logits; }
454
+ ggml_tensor * get_embd() const { return t_embd; }
455
+ ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
456
+
457
+ ggml_cgraph * get_gf() const { return gf; }
458
+ ggml_context * get_ctx() const { return ctx_compute.get(); }
459
+
460
+ int64_t get_max_nodes() const;
461
+
462
+ void reset();
463
+
464
+ void set_inputs(const llama_ubatch * ubatch);
465
+
466
+ // try to update the existing graph result using the new graph parameters in order to reuse it
467
+ // this can only be done if we determine that the resulting graph using the new graph parameters
468
+ // would be identical to the existing graph. in that case, we simply have to update the memory
469
+ // contexts of the input tensors of the graph and we can reuse it for another computation
470
+ // return true if the graph was updated and can be reused
471
+ bool can_reuse(const llm_graph_params & params);
472
+
473
+ llm_graph_input_i * add_input(llm_graph_input_ptr input);
474
+
475
+ void set_params(const llm_graph_params & params);
388
476
 
389
477
  // important graph nodes
390
478
  ggml_tensor * t_tokens = nullptr;
@@ -393,36 +481,31 @@ public:
393
481
  ggml_tensor * t_embd_pooled = nullptr;
394
482
 
395
483
  std::vector<llm_graph_input_ptr> inputs;
396
- };
397
484
 
398
- //
399
- // llm_graph_context
400
- //
485
+ ggml_context_ptr ctx_compute;
401
486
 
402
- // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
403
- using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
487
+ // memory buffers used to evaluate the model
488
+ std::vector<uint8_t> buf_compute_meta;
404
489
 
405
- struct llm_graph_params {
406
- ggml_context * ctx;
490
+ ggml_cgraph * gf;
407
491
 
408
- const llm_arch arch;
492
+ int64_t max_nodes;
409
493
 
410
- const llama_hparams & hparams;
411
- const llama_cparams & cparams;
412
- const llama_ubatch & ubatch;
494
+ private:
495
+ // keep a copy of the previous graph parameters
496
+ // we will use this to determine whether the graph can be reused by comparing them with the new parameters
497
+ // note: these are updated after constructing the new graph
498
+ llm_graph_params params;
413
499
 
414
- ggml_backend_sched_t sched;
415
- ggml_backend_t backend_cpu;
416
-
417
- const llama_adapter_cvec * cvec;
418
- const llama_adapter_loras * loras;
419
- const llama_memory_context_i * mctx;
420
- const llama_cross * cross;
500
+ // env: LLAMA_GRAPH_RESULT_DEBUG
501
+ int debug = 0;
502
+ };
421
503
 
422
- uint32_t n_outputs;
504
+ using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
423
505
 
424
- const llm_graph_cb & cb;
425
- };
506
+ //
507
+ // llm_graph_context
508
+ //
426
509
 
427
510
  // used in build_rs to properly order writes and avoid unnecessary copies
428
511
  using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
@@ -463,8 +546,6 @@ struct llm_graph_context {
463
546
  const enum llama_pooling_type pooling_type;
464
547
  const enum llama_rope_type rope_type;
465
548
 
466
- ggml_context * ctx0 = nullptr;
467
-
468
549
  ggml_backend_sched_t sched;
469
550
 
470
551
  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
@@ -476,7 +557,10 @@ struct llm_graph_context {
476
557
 
477
558
  const llm_graph_cb & cb_func;
478
559
 
479
- std::unique_ptr<llm_graph_result> res;
560
+ llm_graph_result * res;
561
+
562
+ ggml_context * ctx0 = nullptr;
563
+ ggml_cgraph * gf = nullptr;
480
564
 
481
565
  llm_graph_context(const llm_graph_params & params);
482
566
  virtual ~llm_graph_context() = default;
@@ -562,7 +646,6 @@ struct llm_graph_context {
562
646
  //
563
647
 
564
648
  ggml_tensor * build_attn_mha(
565
- ggml_cgraph * gf,
566
649
  ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
567
650
  ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
568
651
  ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@@ -575,7 +658,6 @@ struct llm_graph_context {
575
658
 
576
659
  ggml_tensor * build_attn(
577
660
  llm_graph_input_attn_no_cache * inp,
578
- ggml_cgraph * gf,
579
661
  ggml_tensor * wo,
580
662
  ggml_tensor * wo_b,
581
663
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -590,7 +672,6 @@ struct llm_graph_context {
590
672
 
591
673
  ggml_tensor * build_attn(
592
674
  llm_graph_input_attn_kv_unified * inp,
593
- ggml_cgraph * gf,
594
675
  ggml_tensor * wo,
595
676
  ggml_tensor * wo_b,
596
677
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -606,7 +687,6 @@ struct llm_graph_context {
606
687
  // note: if k_cur or v_cur are not provided, they will not be stored in the memory
607
688
  ggml_tensor * build_attn(
608
689
  llm_graph_input_attn_kv_unified_iswa * inp,
609
- ggml_cgraph * gf,
610
690
  ggml_tensor * wo,
611
691
  ggml_tensor * wo_b,
612
692
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -621,7 +701,6 @@ struct llm_graph_context {
621
701
 
622
702
  ggml_tensor * build_attn(
623
703
  llm_graph_input_attn_cross * inp,
624
- ggml_cgraph * gf,
625
704
  ggml_tensor * wo,
626
705
  ggml_tensor * wo_b,
627
706
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -643,7 +722,6 @@ struct llm_graph_context {
643
722
  // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
644
723
  // `llama_memory_recurrent`
645
724
  ggml_tensor * build_rs(
646
- ggml_cgraph * gf,
647
725
  ggml_tensor * s,
648
726
  ggml_tensor * state_copy,
649
727
  int32_t state_size,
@@ -658,7 +736,6 @@ struct llm_graph_context {
658
736
 
659
737
  ggml_tensor * build_rs(
660
738
  llm_graph_input_rs * inp,
661
- ggml_cgraph * gf,
662
739
  ggml_tensor * s,
663
740
  int32_t state_size,
664
741
  int32_t n_seqs,
@@ -666,9 +743,8 @@ struct llm_graph_context {
666
743
 
667
744
  ggml_tensor * build_rwkv_token_shift_load(
668
745
  llm_graph_input_rs * inp,
669
- ggml_cgraph * gf,
670
746
  const llama_ubatch & ubatch,
671
- int il) const;
747
+ int il) const;
672
748
 
673
749
  ggml_tensor * build_rwkv_token_shift_store(
674
750
  ggml_tensor * token_shift,
@@ -685,7 +761,6 @@ struct llm_graph_context {
685
761
  //
686
762
 
687
763
  void build_pooling(
688
- ggml_cgraph * gf,
689
764
  ggml_tensor * cls,
690
765
  ggml_tensor * cls_b,
691
766
  ggml_tensor * cls_out,
@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
65
65
  return n_embd_head_v * n_head_kv;
66
66
  }
67
67
 
68
+ bool llama_hparams::is_n_embd_k_gqa_variable() const {
69
+ const uint32_t val = n_embd_k_gqa();
70
+ for (uint32_t il = 0; il < n_layer; ++il) {
71
+ if (val != n_embd_k_gqa(il)) {
72
+ return true;
73
+ }
74
+ }
75
+
76
+ return false;
77
+ }
78
+
79
+ bool llama_hparams::is_n_embd_v_gqa_variable() const {
80
+ const uint32_t val = n_embd_v_gqa();
81
+ for (uint32_t il = 0; il < n_layer; ++il) {
82
+ if (val != n_embd_v_gqa(il)) {
83
+ return true;
84
+ }
85
+ }
86
+
87
+ return false;
88
+ }
89
+
90
+ uint32_t llama_hparams::n_embd_k_gqa_max() const {
91
+ uint32_t val = n_embd_k_gqa();
92
+ for (uint32_t il = 0; il < n_layer; ++il) {
93
+ val = std::max(val, n_embd_k_gqa(il));
94
+ }
95
+
96
+ return val;
97
+ }
98
+
99
+ uint32_t llama_hparams::n_embd_v_gqa_max() const {
100
+ uint32_t val = n_embd_v_gqa();
101
+ for (uint32_t il = 0; il < n_layer; ++il) {
102
+ val = std::max(val, n_embd_v_gqa(il));
103
+ }
104
+
105
+ return val;
106
+ }
107
+
68
108
  uint32_t llama_hparams::n_embd_r() const {
69
109
  if (wkv_head_size != 0) {
70
110
  // for RWKV models
@@ -6,7 +6,7 @@
6
6
 
7
7
  // bump if necessary
8
8
  #define LLAMA_MAX_LAYERS 512
9
- #define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
9
+ #define LLAMA_MAX_EXPERTS 384 // Kimi-K2
10
10
 
11
11
  enum llama_expert_gating_func_type {
12
12
  LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
@@ -98,7 +98,7 @@ struct llama_hparams {
98
98
  float rope_freq_scale_train;
99
99
  float rope_freq_scale_train_swa;
100
100
  uint32_t n_ctx_orig_yarn;
101
- float rope_yarn_log_mul;
101
+ float rope_yarn_log_mul = 0.0f;
102
102
 
103
103
  std::array<int, 4> rope_sections;
104
104
 
@@ -191,6 +191,14 @@ struct llama_hparams {
191
191
  // dimension of value embeddings across all k-v heads
192
192
  uint32_t n_embd_v_gqa(uint32_t il = 0) const;
193
193
 
194
+ // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
195
+ bool is_n_embd_k_gqa_variable() const;
196
+ bool is_n_embd_v_gqa_variable() const;
197
+
198
+ // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
199
+ uint32_t n_embd_k_gqa_max() const;
200
+ uint32_t n_embd_v_gqa_max() const;
201
+
194
202
  // dimension of the rolling state embeddings
195
203
  // corresponds to Mamba's conv_states size or RWKV's token_shift states size
196
204
  uint32_t n_embd_r() const;
@@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
18
18
  bool v_trans,
19
19
  bool offload,
20
20
  bool swa_full,
21
+ bool unified,
21
22
  uint32_t kv_size,
22
23
  uint32_t n_seq_max,
23
24
  uint32_t n_ubatch,
24
- uint32_t n_pad) : hparams(model.hparams) {
25
+ uint32_t n_pad) : hparams(model.hparams), unified(unified) {
25
26
  llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
26
27
  llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
27
28
 
28
29
  const uint32_t size_base = kv_size;
29
30
 
30
- uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
31
+ uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
31
32
 
32
33
  // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
33
34
  if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
41
42
 
42
43
  kv_base = std::make_unique<llama_kv_cache_unified>(
43
44
  model, std::move(filter_base), type_k, type_v,
44
- v_trans, offload, size_base, n_seq_max, n_pad,
45
+ v_trans, offload, unified, size_base, n_seq_max, n_pad,
45
46
  0, LLAMA_SWA_TYPE_NONE);
46
47
 
47
48
  LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
48
49
 
49
50
  kv_swa = std::make_unique<llama_kv_cache_unified>(
50
51
  model, std::move(filter_swa), type_k, type_v,
51
- v_trans, offload, size_swa, n_seq_max, n_pad,
52
+ v_trans, offload, unified, size_swa, n_seq_max, n_pad,
52
53
  hparams.n_swa, hparams.swa_type);
53
54
  }
54
55
 
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
100
101
 
101
102
  // first try simple split
102
103
  do {
104
+ if (!unified) {
105
+ // requires equal splits, so we skip the simple split
106
+ break;
107
+ }
108
+
103
109
  balloc.split_reset();
104
110
 
105
111
  std::vector<llama_ubatch> ubatches;
@@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
140
146
 
141
147
  std::vector<llama_ubatch> ubatches;
142
148
  while (true) {
143
- auto ubatch = balloc.split_equal(n_ubatch, false);
149
+ auto ubatch = balloc.split_equal(n_ubatch, !unified);
144
150
 
145
151
  if (ubatch.n_tokens == 0) {
146
152
  break;
@@ -20,6 +20,7 @@ public:
20
20
  bool v_trans,
21
21
  bool offload,
22
22
  bool swa_full,
23
+ bool unified,
23
24
  uint32_t kv_size,
24
25
  uint32_t n_seq_max,
25
26
  uint32_t n_ubatch,
@@ -68,6 +69,8 @@ public:
68
69
  private:
69
70
  const llama_hparams & hparams;
70
71
 
72
+ const bool unified;
73
+
71
74
  std::unique_ptr<llama_kv_cache_unified> kv_base;
72
75
  std::unique_ptr<llama_kv_cache_unified> kv_swa;
73
76
  };