@novastera-oss/llamarn 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190) hide show
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakePresets.json +11 -0
  22. package/cpp/llama.cpp/CODEOWNERS +1 -0
  23. package/cpp/llama.cpp/README.md +4 -3
  24. package/cpp/llama.cpp/common/arg.cpp +45 -1
  25. package/cpp/llama.cpp/common/common.cpp +22 -6
  26. package/cpp/llama.cpp/common/common.h +18 -4
  27. package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
  28. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
  30. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  31. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  32. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  34. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
  35. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
  77. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
  78. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
  79. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
  109. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  112. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  115. package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
  116. package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
  117. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  118. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
  120. package/cpp/llama.cpp/include/llama.h +15 -7
  121. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  122. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  123. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  124. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  125. package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
  126. package/cpp/llama.cpp/src/llama-arch.h +5 -0
  127. package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
  128. package/cpp/llama.cpp/src/llama-batch.h +24 -18
  129. package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
  130. package/cpp/llama.cpp/src/llama-chat.h +2 -0
  131. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  132. package/cpp/llama.cpp/src/llama-context.h +26 -16
  133. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  134. package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
  135. package/cpp/llama.cpp/src/llama-graph.h +147 -72
  136. package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
  137. package/cpp/llama.cpp/src/llama-hparams.h +10 -2
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  139. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  140. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  141. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  142. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
  144. package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
  145. package/cpp/llama.cpp/src/llama-model.h +3 -4
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
  148. package/cpp/llama.cpp/src/llama-vocab.h +2 -0
  149. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  150. package/cpp/llama.cpp/src/unicode.h +2 -0
  151. package/ios/include/common.h +18 -4
  152. package/ios/include/llama.h +15 -7
  153. package/ios/libs/llama.xcframework/Info.plist +15 -15
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  155. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  158. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  165. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  172. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  174. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
  175. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  176. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  177. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  178. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  179. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  180. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
  183. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
  184. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  186. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
  187. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
  188. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  189. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  190. package/package.json +4 -4
@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
28
28
  }
29
29
  }
30
30
 
31
+ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
32
+ bool res = true;
33
+
34
+ res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
35
+ res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
36
+
37
+ return res;
38
+ }
39
+
31
40
  void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
32
41
  if (ubatch->pos && pos) {
33
42
  const int64_t n_tokens = ubatch->n_tokens;
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
50
59
  }
51
60
  }
52
61
 
62
+ bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
63
+ bool res = true;
64
+
65
+ res &= pos->ne[0] == params.ubatch.n_tokens;
66
+
67
+ return res;
68
+ }
69
+
53
70
  void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
54
71
  if (ubatch->pos && attn_scale) {
55
72
  const int64_t n_tokens = ubatch->n_tokens;
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
71
88
  const int64_t n_tokens = ubatch->n_tokens;
72
89
 
73
90
  GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
74
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
91
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
75
92
 
76
93
  int32_t * data = (int32_t *) pos_bucket->data;
77
94
 
@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
118
135
  }
119
136
  }
120
137
 
138
+ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
139
+ bool res = true;
140
+
141
+ res &= n_outputs == params.n_outputs;
142
+
143
+ return res;
144
+ }
145
+
121
146
  void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
122
147
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
123
148
  const int64_t n_tokens = ubatch->n_tokens;
@@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
287
312
  mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
288
313
  }
289
314
 
315
+ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
316
+ const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
317
+
318
+ this->mctx = mctx;
319
+
320
+ bool res = true;
321
+
322
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
323
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
324
+
325
+ res &= self_kq_mask->ne[0] == mctx->get_n_kv();
326
+ res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
327
+
328
+ res &= mctx->get_supports_set_rows(); // TODO: tmp
329
+
330
+ return res;
331
+ }
332
+
290
333
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
291
334
  mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
292
335
  mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
299
342
  mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
300
343
  }
301
344
 
345
+ bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
346
+ const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
347
+
348
+ this->mctx = mctx;
349
+
350
+ bool res = true;
351
+
352
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
353
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
354
+
355
+ res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
356
+ //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
357
+
358
+ res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
359
+ res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
360
+
361
+ res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
362
+ res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
363
+
364
+ res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
365
+
366
+ return res;
367
+ }
368
+
302
369
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
303
370
  GGML_ASSERT(cross_kq_mask);
304
371
 
@@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
306
373
  const int64_t n_tokens = ubatch->n_tokens;
307
374
 
308
375
  GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
309
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
376
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
310
377
 
311
378
  float * data = (float *) cross_kq_mask->data;
312
379
 
@@ -340,6 +407,91 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
340
407
  inp_rs->set_input(ubatch);
341
408
  }
342
409
 
410
+ //
411
+ // llm_graph_result
412
+ //
413
+
414
+ llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
415
+ reset();
416
+
417
+ const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
418
+ debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
419
+ }
420
+
421
+ int64_t llm_graph_result::get_max_nodes() const {
422
+ return max_nodes;
423
+ }
424
+
425
+ void llm_graph_result::reset() {
426
+ t_tokens = nullptr;
427
+ t_logits = nullptr;
428
+ t_embd = nullptr;
429
+ t_embd_pooled = nullptr;
430
+
431
+ params = {};
432
+
433
+ inputs.clear();
434
+
435
+ buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
436
+
437
+ ggml_init_params params = {
438
+ /*.mem_size =*/ buf_compute_meta.size(),
439
+ /*.mem_buffer =*/ buf_compute_meta.data(),
440
+ /*.no_alloc =*/ true,
441
+ };
442
+
443
+ ctx_compute.reset(ggml_init(params));
444
+
445
+ gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
446
+ }
447
+
448
+ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
449
+ for (auto & input : inputs) {
450
+ input->set_input(ubatch);
451
+ }
452
+ }
453
+
454
+ bool llm_graph_result::can_reuse(const llm_graph_params & params) {
455
+ if (!this->params.allow_reuse(params)) {
456
+ if (debug > 1) {
457
+ LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
458
+ }
459
+
460
+ return false;
461
+ }
462
+
463
+ if (debug > 1) {
464
+ LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
465
+ }
466
+
467
+ bool res = true;
468
+
469
+ for (auto & input : inputs) {
470
+ const bool cur = input->can_reuse(params);
471
+
472
+ if (debug > 1) {
473
+ LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
474
+ }
475
+
476
+ res = res && cur;
477
+ }
478
+
479
+ if (debug > 0) {
480
+ LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
481
+ }
482
+
483
+ return res;
484
+ }
485
+
486
+ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
487
+ inputs.emplace_back(std::move(input));
488
+ return inputs.back().get();
489
+ }
490
+
491
+ void llm_graph_result::set_params(const llm_graph_params & params) {
492
+ this->params = params;
493
+ }
494
+
343
495
  //
344
496
  // llm_graph_context
345
497
  //
@@ -374,7 +526,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
374
526
  n_ctx_orig (cparams.n_ctx_orig_yarn),
375
527
  pooling_type (cparams.pooling_type),
376
528
  rope_type (hparams.rope_type),
377
- ctx0 (params.ctx),
378
529
  sched (params.sched),
379
530
  backend_cpu (params.backend_cpu),
380
531
  cvec (params.cvec),
@@ -382,7 +533,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
382
533
  mctx (params.mctx),
383
534
  cross (params.cross),
384
535
  cb_func (params.cb),
385
- res (std::make_unique<llm_graph_result>()) {
536
+ res (params.res),
537
+ ctx0 (res->get_ctx()),
538
+ gf (res->get_gf()) {
539
+ res->set_params(params);
386
540
  }
387
541
 
388
542
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -753,20 +907,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
753
907
  cb(cur, "ffn_moe_weighted", il);
754
908
  }
755
909
 
910
+ ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
911
+
912
+ assert(n_expert_used > 0);
913
+
914
+ // order the views before the adds
915
+ for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
916
+ cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
917
+
918
+ ggml_build_forward_expand(gf, cur_experts[i]);
919
+ }
920
+
756
921
  // aggregate experts
757
- ggml_tensor * moe_out = nullptr;
758
- for (int i = 0; i < n_expert_used; ++i) {
759
- ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
760
- experts->nb[2], i*experts->nb[1]);
922
+ // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
923
+ // to avoid potentially a large number of add nodes during warmup
924
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14753
925
+ ggml_tensor * moe_out = cur_experts[0];
761
926
 
762
- if (i == 0) {
763
- moe_out = cur_expert;
764
- } else {
765
- moe_out = ggml_add(ctx0, moe_out, cur_expert);
766
- }
927
+ for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
928
+ moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
767
929
  }
768
930
 
769
- if (n_expert_used == 1) {
931
+ if (hparams.n_expert_used == 1) {
770
932
  // avoid returning a non-contiguous tensor
771
933
  moe_out = ggml_cont(ctx0, moe_out);
772
934
  }
@@ -972,7 +1134,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
972
1134
  }
973
1135
 
974
1136
  ggml_tensor * llm_graph_context::build_attn_mha(
975
- ggml_cgraph * gf,
976
1137
  ggml_tensor * q,
977
1138
  ggml_tensor * k,
978
1139
  ggml_tensor * v,
@@ -982,13 +1143,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
982
1143
  float kq_scale) const {
983
1144
  const bool v_trans = v->nb[1] > v->nb[2];
984
1145
 
1146
+ // split the batch into streams if needed
1147
+ const auto n_stream = k->ne[3];
1148
+
1149
+ q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
1150
+
985
1151
  q = ggml_permute(ctx0, q, 0, 2, 1, 3);
986
1152
  k = ggml_permute(ctx0, k, 0, 2, 1, 3);
987
1153
  v = ggml_permute(ctx0, v, 0, 2, 1, 3);
988
1154
 
989
- const auto n_tokens = q->ne[1];
990
- const auto n_head = q->ne[2];
991
- const auto n_kv = k->ne[1];
1155
+ const auto n_kv = k->ne[1];
992
1156
 
993
1157
  ggml_tensor * cur;
994
1158
 
@@ -1030,7 +1194,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1030
1194
  #endif
1031
1195
  }
1032
1196
 
1033
- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1197
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1034
1198
  } else {
1035
1199
  ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1036
1200
 
@@ -1075,7 +1239,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1075
1239
 
1076
1240
  cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1077
1241
 
1078
- cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1242
+ // recombine streams
1243
+ cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1079
1244
 
1080
1245
  if (!cparams.offload_kqv) {
1081
1246
  // all nodes between the KV store and the attention output are run on the CPU
@@ -1102,7 +1267,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1102
1267
 
1103
1268
  ggml_tensor * llm_graph_context::build_attn(
1104
1269
  llm_graph_input_attn_no_cache * inp,
1105
- ggml_cgraph * gf,
1106
1270
  ggml_tensor * wo,
1107
1271
  ggml_tensor * wo_b,
1108
1272
  ggml_tensor * q_cur,
@@ -1122,11 +1286,15 @@ ggml_tensor * llm_graph_context::build_attn(
1122
1286
 
1123
1287
  const auto & kq_mask = inp->get_kq_mask();
1124
1288
 
1289
+ // [TAG_NO_CACHE_PAD]
1290
+ // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1291
+ assert(!ubatch.equal_seqs());
1292
+
1125
1293
  ggml_tensor * q = q_cur;
1126
1294
  ggml_tensor * k = k_cur;
1127
1295
  ggml_tensor * v = v_cur;
1128
1296
 
1129
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1297
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1130
1298
  cb(cur, "kqv_out", il);
1131
1299
 
1132
1300
  if (wo) {
@@ -1156,13 +1324,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
1156
1324
  {
1157
1325
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1158
1326
 
1159
- const auto n_kv = mctx_cur->get_n_kv();
1327
+ const auto n_kv = mctx_cur->get_n_kv();
1160
1328
  const auto n_tokens = ubatch.n_tokens;
1329
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1161
1330
 
1162
1331
  inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1163
1332
  inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1164
1333
 
1165
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1334
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1166
1335
  ggml_set_input(inp->self_kq_mask);
1167
1336
 
1168
1337
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1181,7 +1350,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
1181
1350
 
1182
1351
  ggml_tensor * llm_graph_context::build_attn(
1183
1352
  llm_graph_input_attn_kv_unified * inp,
1184
- ggml_cgraph * gf,
1185
1353
  ggml_tensor * wo,
1186
1354
  ggml_tensor * wo_b,
1187
1355
  ggml_tensor * q_cur,
@@ -1214,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
1214
1382
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1215
1383
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1216
1384
 
1217
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1385
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1218
1386
  cb(cur, "kqv_out", il);
1219
1387
 
1220
1388
  if (wo) {
@@ -1234,7 +1402,6 @@ ggml_tensor * llm_graph_context::build_attn(
1234
1402
 
1235
1403
  ggml_tensor * llm_graph_context::build_attn(
1236
1404
  llm_graph_input_attn_kv_unified_iswa * inp,
1237
- ggml_cgraph * gf,
1238
1405
  ggml_tensor * wo,
1239
1406
  ggml_tensor * wo_b,
1240
1407
  ggml_tensor * q_cur,
@@ -1281,7 +1448,7 @@ ggml_tensor * llm_graph_context::build_attn(
1281
1448
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1282
1449
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1283
1450
 
1284
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1451
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1285
1452
  cb(cur, "kqv_out", il);
1286
1453
 
1287
1454
  if (wo) {
@@ -1314,7 +1481,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1314
1481
 
1315
1482
  ggml_tensor * llm_graph_context::build_attn(
1316
1483
  llm_graph_input_attn_cross * inp,
1317
- ggml_cgraph * gf,
1318
1484
  ggml_tensor * wo,
1319
1485
  ggml_tensor * wo_b,
1320
1486
  ggml_tensor * q_cur,
@@ -1336,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_attn(
1336
1502
  ggml_tensor * k = k_cur;
1337
1503
  ggml_tensor * v = v_cur;
1338
1504
 
1339
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1505
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1340
1506
  cb(cur, "kqv_out", il);
1341
1507
 
1342
1508
  if (wo) {
@@ -1362,13 +1528,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1362
1528
 
1363
1529
  auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
1364
1530
 
1531
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1532
+
1365
1533
  {
1366
1534
  const auto n_kv = mctx_cur->get_base()->get_n_kv();
1367
1535
 
1368
1536
  inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1369
1537
  inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1370
1538
 
1371
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1539
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1372
1540
  ggml_set_input(inp->self_kq_mask);
1373
1541
 
1374
1542
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1382,7 +1550,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1382
1550
  inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1383
1551
  inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1384
1552
 
1385
- inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1553
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1386
1554
  ggml_set_input(inp->self_kq_mask_swa);
1387
1555
 
1388
1556
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1392,7 +1560,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1392
1560
  }
1393
1561
 
1394
1562
  ggml_tensor * llm_graph_context::build_rs(
1395
- ggml_cgraph * gf,
1396
1563
  ggml_tensor * s,
1397
1564
  ggml_tensor * state_copy,
1398
1565
  int32_t state_size,
@@ -1450,21 +1617,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1450
1617
 
1451
1618
  ggml_tensor * llm_graph_context::build_rs(
1452
1619
  llm_graph_input_rs * inp,
1453
- ggml_cgraph * gf,
1454
1620
  ggml_tensor * s,
1455
1621
  int32_t state_size,
1456
1622
  int32_t n_seqs,
1457
1623
  const llm_graph_get_rows_fn & get_state_rows) const {
1458
1624
  const auto * kv_state = inp->mctx;
1459
1625
 
1460
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
1626
+ return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
1461
1627
  }
1462
1628
 
1463
1629
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1464
1630
  llm_graph_input_rs * inp,
1465
- ggml_cgraph * gf,
1466
1631
  const llama_ubatch & ubatch,
1467
- int il) const {
1632
+ int il) const {
1468
1633
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1469
1634
 
1470
1635
  const auto token_shift_count = hparams.token_shift_count;
@@ -1474,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1474
1639
  ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
1475
1640
 
1476
1641
  ggml_tensor * token_shift = build_rs(
1477
- inp, gf, token_shift_all,
1642
+ inp, token_shift_all,
1478
1643
  hparams.n_embd_r(), n_seqs);
1479
1644
 
1480
1645
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1514,7 +1679,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1514
1679
  }
1515
1680
 
1516
1681
  void llm_graph_context::build_pooling(
1517
- ggml_cgraph * gf,
1518
1682
  ggml_tensor * cls,
1519
1683
  ggml_tensor * cls_b,
1520
1684
  ggml_tensor * cls_out,