@fugood/llama.node 0.3.17 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/src/llama-context.h

@@ -7,6 +7,7 @@
  #include "llama-adapter.h"

  #include "ggml-cpp.h"
+ #include "ggml-opt.h"

  #include <map>
  #include <vector>
@@ -27,7 +28,12 @@ struct llama_context {

  void synchronize();

- const llama_model & get_model() const;
+ const llama_model & get_model() const;
+ const llama_cparams & get_cparams() const;
+
+ ggml_backend_sched_t get_sched() const;
+
+ ggml_context * get_ctx_compute() const;

  uint32_t n_ctx() const;
  uint32_t n_ctx_per_seq() const;
@@ -128,6 +134,32 @@
  llama_perf_context_data perf_get_data() const;
  void perf_reset();

+ //
+ // training
+ //
+
+ void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
+
+ void opt_epoch(
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result_train,
+ ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ ggml_opt_epoch_callback callback_train,
+ ggml_opt_epoch_callback callback_eval);
+
+ void opt_epoch_iter(
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result,
+ const std::vector<llama_token> & tokens,
+ const std::vector<llama_token> & labels_sparse,
+ llama_batch & batch,
+ ggml_opt_epoch_callback callback,
+ bool train,
+ int64_t idata_in_loop,
+ int64_t ndata_in_loop,
+ int64_t t_loop_start);
+
  private:
  //
  // output
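
The opt_* members added above are the new in-context training hooks (see also the new examples/training/finetune.cpp and llama-model-saver files in the change list). Below is a minimal sketch of how they might be driven, assuming `dataset`, `result_train` and `result_eval` were created through the ggml-opt API declared in ggml-opt.h; the function and variable names are illustrative only, not part of this package:

    // Hypothetical driver for the new training hooks (sketch only, no error handling).
    static void finetune_sketch(llama_context    & ctx,
                                llama_model      * model,
                                llama_opt_params   lopt_params,
                                ggml_opt_dataset_t dataset,
                                ggml_opt_result_t  result_train,
                                ggml_opt_result_t  result_eval,
                                int64_t            idata_split) {
        // one-time setup of the ggml_opt_context_t held by llama_context
        ctx.opt_init(model, lopt_params);

        // one pass over the dataset: entries before idata_split are used for
        // training, the remainder for evaluation; the callbacks are optional
        // ggml_opt_epoch_callback progress hooks
        ctx.opt_epoch(dataset, result_train, result_eval, idata_split,
                      /*callback_train =*/ nullptr,
                      /*callback_eval  =*/ nullptr);
    }
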
@@ -137,49 +169,30 @@ private:
  // Returns max number of outputs for which space was reserved.
  int32_t output_reserve(int32_t n_outputs);

- // make the outputs have the same order they had in the user-provided batch
- // TODO: maybe remove this
- void output_reorder();
-
  //
  // graph
  //

+ public:
  int32_t graph_max_nodes() const;

  // zero-out inputs and create the ctx_compute for the compute graph
  ggml_cgraph * graph_init();

+ // returns the result of ggml_backend_sched_graph_compute_async execution
+ ggml_status graph_compute(
+ ggml_cgraph * gf,
+ bool batched);
+
+ private:
  llm_graph_result_ptr graph_build(
  ggml_context * ctx,
  ggml_cgraph * gf,
  const llama_ubatch & ubatch,
  llm_graph_type gtype);

- // returns the result of ggml_backend_sched_graph_compute_async execution
- ggml_status graph_compute(
- ggml_cgraph * gf,
- bool batched);
-
  llm_graph_cb graph_get_cb() const;

- // used by kv_self_update()
- ggml_tensor * build_rope_shift(
- ggml_context * ctx0,
- ggml_tensor * cur,
- ggml_tensor * shift,
- ggml_tensor * factors,
- float freq_base,
- float freq_scale) const;
-
- llm_graph_result_ptr build_kv_self_shift(
- ggml_context * ctx0,
- ggml_cgraph * gf) const;
-
- llm_graph_result_ptr build_kv_self_defrag(
- ggml_context * ctx0,
- ggml_cgraph * gf) const;
-
  // TODO: read/write lora adapters and cvec
  size_t state_write_data(llama_io_write_i & io);
  size_t state_read_data (llama_io_read_i & io);
@@ -196,14 +209,10 @@ private:
  llama_cparams cparams;
  llama_adapter_cvec cvec;
  llama_adapter_loras loras;
- llama_sbatch sbatch;

  llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably

- std::unique_ptr<llama_kv_cache_unified> kv_self;
-
- // TODO: remove
- bool logits_all = false;
+ std::unique_ptr<llama_memory_i> memory;

  // decode output (2-dimensional array: [n_outputs][n_vocab])
  size_t logits_size = 0; // capacity (of floats) for logits
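
With this change llama_context no longer owns a concrete llama_kv_cache_unified; it holds an abstract llama_memory_i, and the graph-building code downcasts to whichever cache type the model needs, as the llama-graph.cpp hunks further down show. Purely as an illustration of that pattern (not code from the package):

    // `memory` stands for the abstract llama_memory_i pointer the graph builders
    // receive; which cast is valid depends on the model architecture.
    const auto * kv_attn = static_cast<const llama_kv_cache_unified   *>(memory); // attention models
    const auto * kv_rec  = static_cast<const llama_kv_cache_recurrent *>(memory); // recurrent models (RWKV/Mamba-style)
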
@@ -230,6 +239,9 @@

  ggml_context_ptr ctx_compute;

+ // training
+ ggml_opt_context_t opt_ctx = nullptr;
+
  ggml_threadpool_t threadpool = nullptr;
  ggml_threadpool_t threadpool_batch = nullptr;

package/src/llama.cpp/src/llama-cparams.h

@@ -30,6 +30,7 @@ struct llama_cparams {
  bool flash_attn;
  bool no_perf;
  bool warmup;
+ bool op_offload;

  enum llama_pooling_type pooling_type;

package/src/llama.cpp/src/llama-graph.cpp

@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {

  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
  for (uint32_t i = 0; i < n_kv; ++i) {
- const uint32_t cell_id = i + kv_self->head;
-
- //////////////////////////////////////////////
- // TODO: this should not mutate the KV cache !
- llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
- // prevent out-of-bound sources
- if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
- kv_cell.src = cell_id;
- }
-
- data[i] = kv_cell.src;
-
- // TODO: do not mutate the KV cache
- // ensure copy only happens once
- if (kv_cell.src != (int32_t) cell_id) {
- kv_cell.src = cell_id;
- }
+ data[i] = kv_self->s_copy(i);
  }
  }
  }
@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {

  // clear unused states
  for (int i = 0; i < n_kv; ++i) {
- const uint32_t cell_id = i + kv_self->head;
-
- //////////////////////////////////////////////
- // TODO: this should not mutate the KV cache !
- llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
- data[i] = (float) (kv_cell.src >= 0);
-
- // only clear once
- if (kv_cell.src < 0) {
- kv_cell.src = cell_id;
- }
+ data[i] = kv_self->s_mask(i);
  }
  }
  }
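
The per-cell bookkeeping deleted above now lives behind the s_copy()/s_mask() accessors on the recurrent cache (llama-kv-cache.cpp/.h in the change list). Reconstructed from the deleted logic, the accessors plausibly look roughly like this; member and type names are assumptions, not the actual implementation:

    // Sketch of the accessors now called by the two set_input() loops above.
    int32_t llama_kv_cache_recurrent::s_copy(int i) const {
        const uint32_t cell_id = i + head;

        // TODO carried over from the old code: this still mutates the cache state
        auto & cell = const_cast<llama_kv_cache_recurrent *>(this)->cells[i];

        // prevent out-of-bound sources
        if (cell.src < 0 || (uint32_t) cell.src >= size) {
            cell.src = cell_id;
        }

        const int32_t res = cell.src;

        // ensure the copy only happens once
        if (cell.src != (int32_t) cell_id) {
            cell.src = cell_id;
        }

        return res;
    }

    float llama_kv_cache_recurrent::s_mask(int i) const {
        const uint32_t cell_id = i + head;
        auto & cell = const_cast<llama_kv_cache_recurrent *>(this)->cells[i];

        const float res = (float) (cell.src >= 0);

        // only clear the state once
        if (cell.src < 0) {
            cell.src = cell_id;
        }

        return res;
    }
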
@@ -810,7 +782,7 @@ ggml_tensor * llm_graph_context::build_ffn(
  } break;
  }

- if (type_gate == LLM_FFN_PAR) {
+ if (gate && type_gate == LLM_FFN_PAR) {
  cur = ggml_mul(ctx0, cur, tmp);
  cb(cur, "ffn_gate_par", il);
  }
@@ -999,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
  inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
  //cb(inp->tokens, "inp_tokens", -1);
  ggml_set_input(inp->tokens);
+ res->t_tokens = inp->tokens;

  cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);

@@ -1105,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
  }

  ggml_tensor * llm_graph_context::build_inp_s_copy() const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

  auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);

@@ -1122,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
  }

  ggml_tensor * llm_graph_context::build_inp_s_mask() const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

  auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);

@@ -1255,8 +1228,19 @@ ggml_tensor * llm_graph_context::build_attn_mha(
  ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

  if (v_mla) {
+ #if 0
+ // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+ // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
  cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
  cur = ggml_mul_mat(ctx0, v_mla, cur);
+ #else
+ // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+ // The permutations are noops and only change how the tensor data is interpreted.
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+ cur = ggml_mul_mat(ctx0, v_mla, cur);
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+ cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
+ #endif
  }

  cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
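
For context on the #else branch above: the old reshape applied v_mla as one matrix-vector product per token (broadcast over dimension 3), while moving n_tokens into dimension 1 lets a single matrix-matrix multiply per head cover the whole batch, which suits kernels optimized for large dimensions 0 and 1. A shape-only sketch of the new path, with illustrative dimension names:

    // ggml tensors are indexed ne[0..3]; ggml_mul_mat contracts over dim 0.
    // cur   : [head_dim_compressed, n_head, n_tokens, 1]  (attention output)
    // v_mla : [head_dim_compressed, head_dim_v, ...]      (MLA decompression weights)
    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // view as [head_dim_compressed, n_tokens, n_head, 1], no copy
    cur = ggml_mul_mat(ctx0, v_mla, cur);      // [head_dim_v, n_tokens, n_head, 1], one mat-mat per head
    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // view back as [head_dim_v, n_head, n_tokens, 1]
    cur = ggml_cont(ctx0, cur);                // materialize so ggml_reshape_2d below can flatten it
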
@@ -1436,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn(

  // store to KV cache
  {
- GGML_ASSERT(!kv_self->recurrent);
-
  const auto kv_head = kv_self->head;

  GGML_ASSERT(kv_self->size == n_ctx);
@@ -1587,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
  ggml_tensor * state_mask,
  int32_t n_state,
  int32_t n_seqs) const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

  const auto n_kv = kv_self->n;
  const auto kv_head = kv_self->head;
@@ -1619,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
  ggml_tensor * state_mask,
  const llama_ubatch & ubatch,
  int il) const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

  const auto token_shift_count = hparams.token_shift_count;

@@ -1640,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  ggml_tensor * token_shift,
  const llama_ubatch & ubatch,
  int il) const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

  const auto token_shift_count = hparams.token_shift_count;
  const auto n_embd = hparams.n_embd;
package/src/llama.cpp/src/llama-graph.h

@@ -19,6 +19,7 @@ struct llama_cparams;

  class llama_memory_i;
  class llama_kv_cache_unified;
+ class llama_kv_cache_recurrent;

  // certain models (typically multi-modal) can produce different types of graphs
  enum llm_graph_type {
@@ -186,26 +187,26 @@ public:

  class llm_graph_input_s_copy : public llm_graph_input_i {
  public:
- llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+ llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
  virtual ~llm_graph_input_s_copy() = default;

  void set_input(const llama_ubatch * ubatch) override;

  ggml_tensor * s_copy; // I32 [kv_size]

- const llama_kv_cache_unified * kv_self;
+ const llama_kv_cache_recurrent * kv_self;
  };

  class llm_graph_input_s_mask : public llm_graph_input_i {
  public:
- llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+ llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
  virtual ~llm_graph_input_s_mask() = default;

  void set_input(const llama_ubatch * ubatch) override;

  ggml_tensor * s_mask; // F32 [1, n_kv]

- const llama_kv_cache_unified * kv_self;
+ const llama_kv_cache_recurrent * kv_self;
  };

  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -297,6 +298,7 @@ class llm_graph_result_i {
  public:
  virtual ~llm_graph_result_i() = default;

+ virtual ggml_tensor * get_tokens() = 0;
  virtual ggml_tensor * get_logits() = 0;
  virtual ggml_tensor * get_embd() = 0;
  virtual ggml_tensor * get_embd_pooled() = 0;
@@ -311,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i {
  public:
  virtual ~llm_graph_result() = default;

+ ggml_tensor * get_tokens() override { return t_tokens; }
  ggml_tensor * get_logits() override { return t_logits; }
  ggml_tensor * get_embd() override { return t_embd; }
  ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -327,6 +330,7 @@ public:
  }

  // important graph nodes
+ ggml_tensor * t_tokens = nullptr;
  ggml_tensor * t_logits = nullptr;
  ggml_tensor * t_embd = nullptr;
  ggml_tensor * t_embd_pooled = nullptr;
@@ -350,8 +354,8 @@ struct llm_graph_params {
  const llama_cparams & cparams;
  const llama_ubatch & ubatch;

- ggml_backend_sched * sched;
- ggml_backend * backend_cpu;
+ ggml_backend_sched_t sched;
+ ggml_backend_t backend_cpu;

  const llama_adapter_cvec * cvec;
  const llama_adapter_loras * loras;
@@ -402,9 +406,9 @@ struct llm_graph_context {

  ggml_context * ctx0 = nullptr;

- ggml_backend_sched * sched;
+ ggml_backend_sched_t sched;

- ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+ ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

  const llama_adapter_cvec * cvec;
  const llama_adapter_loras * loras;