@fugood/llama.node 1.0.2 → 1.0.4

This diff shows the contents of publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (50)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/CMakeLists.txt +0 -1
  3. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  4. package/src/llama.cpp/common/arg.cpp +44 -0
  5. package/src/llama.cpp/common/common.cpp +22 -6
  6. package/src/llama.cpp/common/common.h +15 -1
  7. package/src/llama.cpp/ggml/CMakeLists.txt +10 -2
  8. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  9. package/src/llama.cpp/ggml/include/ggml.h +104 -10
  10. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  11. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  12. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +749 -163
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
  16. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +12 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +88 -9
  19. package/src/llama.cpp/include/llama.h +13 -47
  20. package/src/llama.cpp/src/llama-arch.cpp +298 -3
  21. package/src/llama.cpp/src/llama-arch.h +22 -1
  22. package/src/llama.cpp/src/llama-batch.cpp +103 -71
  23. package/src/llama.cpp/src/llama-batch.h +31 -18
  24. package/src/llama.cpp/src/llama-chat.cpp +59 -1
  25. package/src/llama.cpp/src/llama-chat.h +3 -0
  26. package/src/llama.cpp/src/llama-context.cpp +134 -95
  27. package/src/llama.cpp/src/llama-context.h +13 -16
  28. package/src/llama.cpp/src/llama-cparams.h +3 -2
  29. package/src/llama.cpp/src/llama-graph.cpp +279 -180
  30. package/src/llama.cpp/src/llama-graph.h +183 -122
  31. package/src/llama.cpp/src/llama-hparams.cpp +47 -1
  32. package/src/llama.cpp/src/llama-hparams.h +12 -1
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  34. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  35. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  36. package/src/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  37. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  38. package/src/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  39. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  40. package/src/llama.cpp/src/llama-memory-recurrent.cpp +21 -11
  41. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  42. package/src/llama.cpp/src/llama-memory.h +3 -0
  43. package/src/llama.cpp/src/llama-model.cpp +3373 -743
  44. package/src/llama.cpp/src/llama-model.h +20 -4
  45. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  46. package/src/llama.cpp/src/llama-vocab.cpp +376 -10
  47. package/src/llama.cpp/src/llama-vocab.h +43 -0
  48. package/src/llama.cpp/src/unicode.cpp +207 -0
  49. package/src/llama.cpp/src/unicode.h +2 -0
  50. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50

package/src/llama.cpp/src/llama-kv-cache-unified.h
@@ -24,8 +24,6 @@ public:
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;
 
-    using ubatch_heads = std::vector<uint32_t>;
-
     struct defrag_info {
         bool empty() const {
             return ids.empty();
@@ -37,6 +35,63 @@ public:
         std::vector<uint32_t> ids;
     };
 
+    struct stream_copy_info {
+        bool empty() const {
+            assert(ssrc.size() == sdst.size());
+            return ssrc.empty();
+        }
+
+        std::vector<uint32_t> ssrc;
+        std::vector<uint32_t> sdst;
+    };
+
+    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
+    // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
+    struct slot_info {
+        // data for ggml_set_rows
+        using idx_vec_t = std::vector<uint32_t>;
+
+        // number of streams: ns = s1 - s0 + 1
+        llama_seq_id s0;
+        llama_seq_id s1;
+
+        std::vector<llama_seq_id> strm; // [ns]
+        std::vector<idx_vec_t>    idxs; // [ns]
+
+        uint32_t head() const {
+            GGML_ASSERT(idxs.size() == 1);
+            GGML_ASSERT(!idxs[0].empty());
+
+            return idxs[0][0];
+        }
+
+        void resize(size_t n) {
+            strm.resize(n);
+            idxs.resize(n);
+        }
+
+        size_t size() const {
+            GGML_ASSERT(idxs.size() == strm.size());
+            GGML_ASSERT(!idxs.empty());
+
+            return idxs[0].size();
+        }
+
+        size_t n_stream() const {
+            return strm.size();
+        }
+
+        bool empty() const {
+            return idxs.empty();
+        }
+
+        void clear() {
+            idxs.clear();
+        }
+    };
+
+    using slot_info_vec_t = std::vector<slot_info>;
+
     llama_kv_cache_unified(
             const llama_model & model,
             layer_filter_cb && filter,
@@ -44,6 +99,7 @@ public:
             ggml_type type_v,
             bool v_trans,
             bool offload,
+            bool unified,
             uint32_t kv_size,
             uint32_t n_seq_max,
             uint32_t n_pad,
@@ -87,7 +143,8 @@ public:
     // llama_kv_cache_unified specific API
     //
 
-    uint32_t get_size() const;
+    uint32_t get_size()     const;
+    uint32_t get_n_stream() const;
 
     bool get_has_shift() const;
 
@@ -97,37 +154,48 @@ public:
 
     uint32_t get_n_kv() const;
 
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
 
     //
     // preparation API
     //
 
-    // find places for the provided ubatches in the cache, returns the head locations
+    // find places for the provided ubatches in the cache, returns the slot infos
     // return empty vector on failure
-    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
 
-    // return the cell position where we can insert the ubatch
-    // return -1 on failure to find a contiguous slot of kv cells
-    int32_t find_slot(const llama_ubatch & ubatch) const;
+    // find a slot of kv cells that can hold the ubatch
+    // if cont == true, then the slot must be continuous
+    // return empty slot_info on failure
+    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
 
-    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
-    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
+    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
+    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
 
     //
-    // set_input API
+    // input API
     //
 
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
+    void set_input_k_shift(ggml_tensor * dst) const;
+
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_k_shift (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
@@ -141,15 +209,15 @@ private:
 
         ggml_tensor * k;
         ggml_tensor * v;
+
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
     };
 
     bool v_trans = true; // the value tensor is transposed
 
-    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    uint32_t head = 0;
-
     const uint32_t n_seq_max = 1;
+    const uint32_t n_stream  = 1;
 
     // required padding
     const uint32_t n_pad = 1;
@@ -157,14 +225,29 @@ private:
     // SWA
     const uint32_t n_swa = 0;
 
+    // env: LLAMA_KV_CACHE_DEBUG
    int debug = 0;
 
+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    bool supports_set_rows = false;
+
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    llama_kv_cells_unified cells;
+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    std::vector<uint32_t> v_heads;
+
+    std::vector<llama_kv_cells_unified> v_cells;
+
+    // maps from a sequence id to a stream id
+    std::vector<uint32_t> seq_to_stream;
+
+    // pending stream copies that will be applied during the next update
+    stream_copy_info sc_info;
 
     std::vector<kv_layer> layers;
 
@@ -190,29 +273,34 @@ private:
             float freq_base,
             float freq_scale) const;
 
-    llm_graph_result_ptr build_graph_shift(
-            const llama_cparams & cparams,
-            ggml_context * ctx,
-            ggml_cgraph * gf) const;
+    ggml_cgraph * build_graph_shift(
+            llm_graph_result * res,
+            llama_context * lctx) const;
 
-    llm_graph_result_ptr build_graph_defrag(
-            const llama_cparams & cparams,
-            ggml_context * ctx,
-            ggml_cgraph * gf,
+    ggml_cgraph * build_graph_defrag(
+            llm_graph_result * res,
+            llama_context * lctx,
             const defrag_info & dinfo) const;
 
-    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
-    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+    struct cell_ranges_t {
+        uint32_t strm;
+
+        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
+    };
+
+    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
 
-    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
 };
 
 class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
-    using defrag_info  = llama_kv_cache_unified::defrag_info;
+    using slot_info_vec_t  = llama_kv_cache_unified::slot_info_vec_t;
+    using defrag_info      = llama_kv_cache_unified::defrag_info;
+    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
 
     // used for errors
     llama_kv_cache_unified_context(llama_memory_status status);
@@ -226,12 +314,13 @@ public:
             llama_kv_cache_unified * kv,
             llama_context * lctx,
            bool do_shift,
-            defrag_info dinfo);
+            defrag_info dinfo,
+            stream_copy_info sc_info);
 
     // used to create a batch procesing context from a batch
     llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
-            ubatch_heads heads,
+            slot_info_vec_t sinfos,
            std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_context();
@@ -252,16 +341,24 @@ public:
 
     uint32_t get_n_kv() const;
 
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
 
-    void set_input_k_shift(ggml_tensor * dst) const;
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
+    void set_input_k_shift (ggml_tensor * dst) const;
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
@@ -279,14 +376,16 @@ private:
 
     defrag_info dinfo;
 
+    stream_copy_info sc_info;
+
     //
     // batch processing context
     //
 
-    // the index of the next ubatch to process
-    size_t i_next = 0;
+    // the index of the cur ubatch to process
+    size_t i_cur = 0;
 
-    ubatch_heads heads;
+    slot_info_vec_t sinfos;
 
     std::vector<llama_ubatch> ubatches;
 
@@ -297,7 +396,4 @@ private:
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // as the cache gets filled, the benefit from this heuristic disappears
     int32_t n_kv;
-
-    // the beginning of the current slot in which the ubatch will be inserted
-    int32_t head;
 };
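
Note: the slot_info struct introduced above replaces the single "head" offset with per-stream lists of cell indices, which is what lets ggml_set_rows scatter K/V data into non-contiguous cells. The following standalone sketch is illustrative only (simplified struct and hypothetical main(), not code from the package); it shows the mapping described in the header comment, token[i] -> cells[idxs[i]]:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

struct slot_info_sketch {
    using idx_vec_t = std::vector<uint32_t>;

    std::vector<int32_t>   strm; // stream id per entry     [ns]
    std::vector<idx_vec_t> idxs; // cell indices per stream [ns]

    // with a single stream, the first index plays the role of the old "head"
    uint32_t head() const {
        assert(idxs.size() == 1 && !idxs[0].empty());
        return idxs[0][0];
    }
};

int main() {
    // one stream whose tokens land in non-contiguous cells
    slot_info_sketch sinfo;
    sinfo.strm = {0};
    sinfo.idxs = {{4, 5, 9, 10}};

    // token[i] of the ubatch is written to cells[idxs[0][i]]
    for (size_t i = 0; i < sinfo.idxs[0].size(); ++i) {
        std::printf("token %zu -> cell %u\n", i, sinfo.idxs[0][i]);
    }

    std::printf("head (first cell) = %u\n", sinfo.head());
    return 0;
}
```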

package/src/llama.cpp/src/llama-kv-cells.h
@@ -105,10 +105,30 @@ public:
         res.resize(n);
 
         for (uint32_t j = 0; j < n; ++j) {
-            res.pos[j] = pos[i + j];
-            res.seq[j] = seq[i + j];
+            const auto idx = i + j;
 
-            assert(shift[i + j] == 0);
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
+        }
+
+        return res;
+    }
+
+    // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells_unified res;
+
+        res.resize(idxs.size());
+
+        for (uint32_t j = 0; j < idxs.size(); ++j) {
+            const auto idx = idxs[j];
+
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
         }
 
         return res;
@@ -119,26 +139,58 @@ public:
         assert(i + other.pos.size() <= pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
-            if (pos[i + j] == -1 && other.pos[j] != -1) {
+            const auto idx = i + j;
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
                 used.insert(i + j);
             }
 
-            if (pos[i + j] != -1 && other.pos[j] == -1) {
+            if (pos[idx] != -1 && other.pos[j] == -1) {
                 used.erase(i + j);
             }
 
-            if (pos[i + j] != -1) {
+            if (pos[idx] != -1) {
                 seq_pos_rm(i + j);
             }
 
-            pos[i + j] = other.pos[j];
-            seq[i + j] = other.seq[j];
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];
 
-            if (pos[i + j] != -1) {
+            if (pos[idx] != -1) {
                 seq_pos_add(i + j);
             }
 
-            assert(shift[i + j] == 0);
+            assert(shift[idx] == 0);
+        }
+    }
+
+    // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
+        assert(idxs.size() == other.pos.size());
+
+        for (uint32_t j = 0; j < other.pos.size(); ++j) {
+            const auto idx = idxs[j];
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
+                used.insert(idx);
+            }
+
+            if (pos[idx] != -1 && other.pos[j] == -1) {
+                used.erase(idx);
+            }
+
+            if (pos[idx] != -1) {
+                seq_pos_rm(idx);
+            }
+
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];
+
+            if (pos[idx] != -1) {
+                seq_pos_add(idx);
+            }
+
+            assert(shift[idx] == 0);
         }
     }
 
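
Note: cp() and set() previously worked only on a contiguous range [i, i + n); the new overloads above accept an arbitrary list of cell indices. Below is a reduced, standalone sketch of that save/restore-by-index pattern; the cells_sketch type keeps only a pos field and all names are hypothetical, this is not the package's implementation:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct cells_sketch {
    std::vector<int32_t> pos; // -1 means the cell is empty

    // copy the state of the cells listed in idxs (analogue of the new cp())
    cells_sketch cp(const std::vector<uint32_t> & idxs) const {
        cells_sketch res;
        res.pos.resize(idxs.size());
        for (size_t j = 0; j < idxs.size(); ++j) {
            res.pos[j] = pos[idxs[j]];
        }
        return res;
    }

    // write the saved state back into the listed cells (analogue of the new set())
    void set(const std::vector<uint32_t> & idxs, const cells_sketch & other) {
        for (size_t j = 0; j < idxs.size(); ++j) {
            pos[idxs[j]] = other.pos[j];
        }
    }
};

int main() {
    cells_sketch cells;
    cells.pos = {0, 1, -1, 7, 8};

    const std::vector<uint32_t> idxs = {1, 3, 4}; // non-contiguous cells
    cells_sketch backup = cells.cp(idxs);         // save their state

    cells.pos[3] = -1;                            // mutate
    cells.set(idxs, backup);                      // restore

    std::printf("cell 3 restored to pos %d\n", cells.pos[3]);
    return 0;
}
```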

package/src/llama.cpp/src/llama-memory-hybrid.cpp
@@ -38,6 +38,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         type_v,
         v_trans,
         offload,
+        1,
         kv_size,
         n_seq_max,
         n_pad,
@@ -70,7 +71,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch);
+                ubatch = balloc.split_equal(n_ubatch, false);
             }
 
             if (ubatch.n_tokens == 0) {
@@ -80,6 +81,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         // prepare the recurrent batches first
         if (!mem_recr->prepare(ubatches)) {
             // TODO: will the recurrent cache be in an undefined context at this point?
@@ -195,11 +201,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
 
 llama_memory_hybrid_context::llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
-        std::vector<uint32_t> heads_attn,
+        slot_info_vec_t sinfos_attn,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
@@ -218,7 +224,7 @@ bool llama_memory_hybrid_context::next() {
 }
 
 bool llama_memory_hybrid_context::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
 
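
Note: both init_batch() implementations in this diff now abort the split loop when the allocator could not place every token of the batch. A standalone sketch of that guard follows; balloc_t, ubatch_t, and collect_ubatches are hypothetical stand-ins, not the real llama_batch_allocr API:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct ubatch_t {
    uint32_t n_tokens = 0;
};

// toy stand-in for llama_batch_allocr: hands out chunks until the batch is used up
struct balloc_t {
    uint32_t n_tokens = 0;
    uint32_t n_used   = 0;

    ubatch_t next_ubatch(uint32_t n) {
        const uint32_t take = std::min(n, n_tokens - n_used);
        n_used += take;
        return ubatch_t{take};
    }

    uint32_t get_n_used()   const { return n_used; }
    uint32_t get_n_tokens() const { return n_tokens; }
};

// returns false if the splitter could not place every token of the batch,
// mirroring the new "failed to find a suitable split" early exit above
bool collect_ubatches(balloc_t & balloc, uint32_t n_ubatch, std::vector<ubatch_t> & out) {
    while (true) {
        ubatch_t ub = balloc.next_ubatch(n_ubatch);
        if (ub.n_tokens == 0) {
            break;
        }
        out.push_back(ub);
    }

    return balloc.get_n_used() >= balloc.get_n_tokens();
}
```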

package/src/llama.cpp/src/llama-memory-hybrid.h
@@ -92,6 +92,8 @@ private:
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
 
@@ -107,7 +109,7 @@ public:
     // init success
     llama_memory_hybrid_context(
             llama_memory_hybrid * mem,
-            std::vector<uint32_t> heads_attn,
+            slot_info_vec_t sinfos_attn,
             std::vector<llama_ubatch> ubatches);
 
     ~llama_memory_hybrid_context() = default;

package/src/llama.cpp/src/llama-memory-recurrent.cpp
@@ -25,9 +25,6 @@ llama_memory_recurrent::llama_memory_recurrent(
         uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
-            __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
-
     head = 0;
     size = mem_size;
     used = 0;
@@ -84,7 +81,7 @@
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            throw std::runtime_error("failed to create ggml context for kv cache");
+            throw std::runtime_error("failed to create ggml context for rs cache");
         }
 
         ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
@@ -102,10 +99,10 @@
 
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            throw std::runtime_error("failed to allocate buffer for kv cache");
+            throw std::runtime_error("failed to allocate buffer for rs cache");
         }
         ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+        LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
 
@@ -113,8 +110,8 @@ llama_memory_recurrent::llama_memory_recurrent(
     const size_t memory_size_r = size_r_bytes();
     const size_t memory_size_s = size_s_bytes();
 
-    LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
-            (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+            (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
             ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
             ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
 }
@@ -374,7 +371,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch);
+                ubatch = balloc.split_equal(n_ubatch, false);
             }
 
             if (ubatch.n_tokens == 0) {
@@ -384,6 +381,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         if (!prepare(ubatches)) {
             break;
         }
@@ -444,7 +446,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
     // A slot should be always be contiguous.
 
     // can only process batches with an equal number of new tokens in each sequence
-    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.equal_seqs());
 
     int32_t min = size - 1;
     int32_t max = 0;
@@ -1071,7 +1073,15 @@ bool llama_memory_recurrent_context::next() {
 }
 
 bool llama_memory_recurrent_context::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    assert(!llama_memory_status_is_fail(status));
+
+    // no ubatches -> this is an update
+    if (ubatches.empty()) {
+        // recurrent cache never performs updates
+        assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE);
+
+        return true;
+    }
 
     mem->find_slot(ubatches[i_next]);
 

package/src/llama.cpp/src/llama-memory.cpp
@@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
     // if either status has an update, then the combined status has an update
     return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
 }
+
+bool llama_memory_status_is_fail(llama_memory_status status) {
+    switch (status) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                return false;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return true;
+            }
+    }
+
+    return false;
+}
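
Note: the new llama_memory_status_is_fail() helper lets the apply() methods above accept LLAMA_MEMORY_STATUS_NO_UPDATE, which the old assert(status == LLAMA_MEMORY_STATUS_SUCCESS) would have rejected. A minimal sketch of that idea with hypothetical, renamed types (not the package's code):

```cpp
#include <cassert>

enum memory_status_sketch {
    STATUS_SUCCESS,
    STATUS_NO_UPDATE,
    STATUS_FAILED_PREPARE,
    STATUS_FAILED_COMPUTE,
};

// only the two FAILED_* states count as failure
static bool status_is_fail(memory_status_sketch s) {
    return s == STATUS_FAILED_PREPARE || s == STATUS_FAILED_COMPUTE;
}

bool apply_sketch(memory_status_sketch status) {
    assert(!status_is_fail(status)); // NO_UPDATE passes, unlike the old check
    (void) status;                   // silence unused warning when asserts are disabled
    // ... apply the pending ubatch or update here ...
    return true;
}
```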

package/src/llama.cpp/src/llama-memory.h
@@ -31,6 +31,9 @@ enum llama_memory_status {
 // useful for implementing hybrid memory types (e.g. iSWA)
 llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
 
+// helper function for checking if a memory status indicates a failure
+bool llama_memory_status_is_fail(llama_memory_status status);
+
 // the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
 //   - llama_kv_cache_unified_context