@fugood/llama.node 1.0.0-beta.5 → 1.0.0-beta.6

This diff compares publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (110)
  1. package/lib/binding.ts +1 -1
  2. package/package.json +14 -14
  3. package/scripts/llama.cpp.patch +27 -26
  4. package/src/LlamaCompletionWorker.cpp +21 -4
  5. package/src/LlamaCompletionWorker.h +2 -0
  6. package/src/LlamaContext.cpp +3 -12
  7. package/src/common.hpp +6 -5
  8. package/src/llama.cpp/CMakeLists.txt +15 -4
  9. package/src/llama.cpp/common/CMakeLists.txt +15 -24
  10. package/src/llama.cpp/common/arg.cpp +172 -110
  11. package/src/llama.cpp/common/chat-parser.cpp +385 -0
  12. package/src/llama.cpp/common/chat-parser.h +120 -0
  13. package/src/llama.cpp/common/chat.cpp +726 -596
  14. package/src/llama.cpp/common/chat.h +74 -8
  15. package/src/llama.cpp/common/common.cpp +56 -38
  16. package/src/llama.cpp/common/common.h +9 -3
  17. package/src/llama.cpp/common/json-partial.cpp +256 -0
  18. package/src/llama.cpp/common/json-partial.h +38 -0
  19. package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  20. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -4
  21. package/src/llama.cpp/common/sampling.cpp +7 -8
  22. package/src/llama.cpp/common/speculative.cpp +6 -4
  23. package/src/llama.cpp/ggml/CMakeLists.txt +48 -3
  24. package/src/llama.cpp/ggml/include/ggml.h +22 -3
  25. package/src/llama.cpp/ggml/src/CMakeLists.txt +81 -22
  26. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +131 -49
  27. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  28. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  29. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  30. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  31. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2162 -0
  32. package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  33. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  34. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  35. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  36. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  37. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  38. package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  39. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  40. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  41. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  42. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  43. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +12 -13
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +64 -88
  45. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  46. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  47. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  48. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  49. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  50. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +282 -100
  51. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  52. package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  53. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  54. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1570 -0
  55. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  56. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +119 -5
  57. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  58. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  59. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +204 -49
  60. package/src/llama.cpp/include/llama.h +145 -40
  61. package/src/llama.cpp/src/CMakeLists.txt +5 -1
  62. package/src/llama.cpp/src/llama-arch.cpp +99 -3
  63. package/src/llama.cpp/src/llama-arch.h +10 -1
  64. package/src/llama.cpp/src/llama-batch.cpp +728 -272
  65. package/src/llama.cpp/src/llama-batch.h +112 -54
  66. package/src/llama.cpp/src/llama-chat.cpp +19 -2
  67. package/src/llama.cpp/src/llama-chat.h +1 -0
  68. package/src/llama.cpp/src/llama-context.cpp +525 -339
  69. package/src/llama.cpp/src/llama-context.h +38 -17
  70. package/src/llama.cpp/src/llama-cparams.cpp +4 -0
  71. package/src/llama.cpp/src/llama-cparams.h +2 -0
  72. package/src/llama.cpp/src/llama-grammar.cpp +12 -2
  73. package/src/llama.cpp/src/llama-graph.cpp +413 -353
  74. package/src/llama.cpp/src/llama-graph.h +112 -56
  75. package/src/llama.cpp/src/llama-hparams.cpp +10 -2
  76. package/src/llama.cpp/src/llama-hparams.h +13 -2
  77. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +279 -0
  78. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +128 -0
  79. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +1815 -0
  80. package/src/llama.cpp/src/llama-kv-cache-unified.h +303 -0
  81. package/src/llama.cpp/src/llama-kv-cells.h +415 -0
  82. package/src/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  83. package/src/llama.cpp/src/llama-memory-hybrid.h +138 -0
  84. package/src/llama.cpp/src/llama-memory-recurrent.cpp +1112 -0
  85. package/src/llama.cpp/src/llama-memory-recurrent.h +183 -0
  86. package/src/llama.cpp/src/llama-memory.cpp +41 -0
  87. package/src/llama.cpp/src/llama-memory.h +86 -5
  88. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  89. package/src/llama.cpp/src/llama-model-loader.cpp +42 -17
  90. package/src/llama.cpp/src/llama-model-saver.cpp +1 -0
  91. package/src/llama.cpp/src/llama-model.cpp +1137 -528
  92. package/src/llama.cpp/src/llama-model.h +4 -0
  93. package/src/llama.cpp/src/llama-quant.cpp +2 -1
  94. package/src/llama.cpp/src/llama-sampling.cpp +2 -2
  95. package/src/llama.cpp/src/llama-vocab.cpp +69 -32
  96. package/src/llama.cpp/src/llama-vocab.h +1 -0
  97. package/src/llama.cpp/src/llama.cpp +11 -7
  98. package/src/llama.cpp/src/unicode.cpp +5 -0
  99. package/src/tts_utils.h +1 -1
  100. package/src/llama.cpp/common/json.hpp +0 -24766
  101. package/src/llama.cpp/common/minja/chat-template.hpp +0 -541
  102. package/src/llama.cpp/common/minja/minja.hpp +0 -2974
  103. package/src/llama.cpp/common/stb_image.h +0 -7988
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13326
  106. package/src/llama.cpp/src/llama-kv-cache.cpp +0 -2827
  107. package/src/llama.cpp/src/llama-kv-cache.h +0 -515
  108. /package/src/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  109. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  110. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -3,7 +3,11 @@
  #include "llama-impl.h"
  #include "llama-batch.h"
  #include "llama-cparams.h"
- #include "llama-kv-cache.h"
+
+ #include "llama-kv-cache-unified.h"
+ #include "llama-kv-cache-unified-iswa.h"
+ #include "llama-memory-hybrid.h"
+ #include "llama-memory-recurrent.h"

  #include <cassert>
  #include <cmath>
@@ -83,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {

  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
  if (pos_bucket) {
- kv_self->set_input_pos_bucket(pos_bucket, ubatch);
+ mctx->set_input_pos_bucket(pos_bucket, ubatch);
  }
  }

  void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
- //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
+ GGML_ASSERT(out_ids);

- if (!out_ids) {
- LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
- } else {
- const int64_t n_tokens = ubatch->n_tokens;
+ const int64_t n_tokens = ubatch->n_tokens;

- GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
- int32_t * data = (int32_t *) out_ids->data;
+ GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
+ int32_t * data = (int32_t *) out_ids->data;

- if (n_outputs == n_tokens) {
- for (int i = 0; i < n_tokens; ++i) {
- data[i] = i;
- }
- } else if (ubatch->output) {
- int32_t n_outputs = 0;
- for (int i = 0; i < n_tokens; ++i) {
- if (ubatch->output[i]) {
- data[n_outputs++] = i;
- }
- }
- // the graph needs to have been passed the correct number of outputs
- GGML_ASSERT(n_outputs == n_outputs);
- } else if (n_outputs == 1) {
- // only keep last output
- data[0] = n_tokens - 1;
- } else {
- GGML_ASSERT(n_outputs == 0);
- }
+ if (n_outputs == n_tokens) {
+ for (int i = 0; i < n_tokens; ++i) {
+ data[i] = i;
+ }
+
+ return;
+ }
+
+ GGML_ASSERT(ubatch->output);
+
+ int n_outputs = 0;
+
+ for (int i = 0; i < n_tokens; ++i) {
+ if (ubatch->output[i]) {
+ data[n_outputs++] = i;
  }
  }
  }
@@ -126,139 +122,114 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
  const int64_t n_tokens = ubatch->n_tokens;
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;

  GGML_ASSERT(mean);
  GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));

  float * data = (float *) mean->data;
- memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
-
- std::vector<uint64_t> sum(n_tokens, 0);
+ memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));

- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
+ std::vector<uint64_t> sums(n_seqs_unq, 0);
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];

- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
- sum[seq_id] += ubatch->n_seq_tokens;
+ sums[seq_idx] += ubatch->n_seq_tokens;
+ }
  }

- std::vector<float> div(n_tokens, 0.0f);
- for (int i = 0; i < n_tokens; ++i) {
- const uint64_t s = sum[i];
- if (s > 0) {
- div[i] = 1.0f/float(s);
+ std::vector<float> div(n_seqs_unq, 0.0f);
+ for (int s = 0; s < n_seqs_unq; ++s) {
+ const uint64_t sum = sums[s];
+ if (sum > 0) {
+ div[s] = 1.0f/float(sum);
  }
  }

- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];

- for (int i = 0; i < n_seq_tokens; ++i) {
- data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
+ for (int j = 0; j < n_seq_tokens; ++j) {
+ data[seq_idx*n_tokens + i + j] = div[seq_idx];
+ }
  }
  }
  }
  }

  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
- if (cparams.embeddings && (
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
+ const int64_t n_tokens = ubatch->n_tokens;
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;

+ if (cparams.embeddings && (
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
+ )) {
  GGML_ASSERT(cls);
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

  uint32_t * data = (uint32_t *) cls->data;
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];

- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
-
- for (int i = 0; i < n_seq_tokens; ++i) {
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
-
- if (pos == 0) {
- data[seq_id] = s*n_seq_tokens + i;
- }
+ data[seq_idx] = i;
  }
  }
  }

  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
-
  GGML_ASSERT(cls);
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

  uint32_t * data = (uint32_t *) cls->data;
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

- std::vector<int> last_pos(n_tokens, -1);
- std::vector<int> last_row(n_tokens, -1);
+ std::vector<int> last_pos(n_seqs_unq, -1);
+ std::vector<int> last_row(n_seqs_unq, -1);

- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+ for (int i = 0; i < n_tokens; ++i) {
+ const llama_pos pos = ubatch->pos[i];

- for (int i = 0; i < n_seq_tokens; ++i) {
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];

- if (pos >= last_pos[seq_id]) {
- last_pos[seq_id] = pos;
- last_row[seq_id] = s*n_seq_tokens + i;
+ if (pos >= last_pos[seq_idx]) {
+ last_pos[seq_idx] = pos;
+ last_row[seq_idx] = i;
  }
  }
  }

- for (int i = 0; i < n_tokens; ++i) {
- if (last_row[i] >= 0) {
- data[i] = last_row[i];
+ for (int s = 0; s < n_seqs_unq; ++s) {
+ if (last_row[s] >= 0) {
+ data[s] = last_row[s];
  }
  }
  }
  }

- void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
  GGML_UNUSED(ubatch);

- const int64_t n_kv = kv_self->n;
+ const int64_t n_rs = mctx->get_n_rs();

  if (s_copy) {
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
  int32_t * data = (int32_t *) s_copy->data;

  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
- for (uint32_t i = 0; i < n_kv; ++i) {
- data[i] = kv_self->s_copy(i);
- }
- }
- }
-
- void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
- GGML_UNUSED(ubatch);
-
- const int64_t n_kv = kv_self->n;
-
- if (s_mask) {
- GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
- float * data = (float *) s_mask->data;
-
- // clear unused states
- for (int i = 0; i < n_kv; ++i) {
- data[i] = kv_self->s_mask(i);
+ for (uint32_t i = 0; i < n_rs; ++i) {
+ data[i] = mctx->s_copy(i);
  }
  }
  }
@@ -274,87 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  }

  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
- if (kq_mask) {
- if (cparams.causal_attn) {
- const int64_t n_kv = ubatch->n_tokens;
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
-
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
- float * data = (float *) kq_mask->data;
-
- for (int h = 0; h < 1; ++h) {
- for (int s1 = 0; s1 < n_seqs; ++s1) {
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
- for (int j = 0; j < n_seq_tokens; ++j) {
- const int32_t tj = s1*n_seq_tokens + j;
-
- for (int s0 = 0; s0 < n_seqs; ++s0) {
- for (int i = 0; i < n_seq_tokens; ++i) {
- const int32_t ti = s0*n_seq_tokens + i;
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
- if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
- } else {
- f = 0.0f;
- }
- break;
- }
- }
-
- data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
- }
- }
- }
- }
- }
- } else {
- const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
- const int64_t n_seqs = ubatch->n_seqs;
- const int64_t n_stride = ubatch->n_tokens;
-
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
- float * data = (float *) kq_mask->data;
-
- for (int h = 0; h < 1; ++h) {
- for (int s1 = 0; s1 < n_seqs; ++s1) {
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
- for (int j = 0; j < n_seq_tokens; ++j) {
- const int32_t tj = s1*n_seq_tokens + j;
-
- for (int s0 = 0; s0 < n_seqs; ++s0) {
- for (int i = 0; i < n_seq_tokens; ++i) {
- const int32_t ti = s0*n_seq_tokens + i;
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
- if (ubatch->seq_id[s0][s] == seq_id) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
- } else {
- f = 0.0f;
- }
- break;
- }
- }
-
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
- }
- }
+ const int64_t n_kv = ubatch->n_tokens;
+ const int64_t n_tokens = ubatch->n_tokens;
+
+ GGML_ASSERT(kq_mask);
+ GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+ float * data = (float *) kq_mask->data;

- for (int i = n_tokens; i < n_stride; ++i) {
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+ for (int h = 0; h < 1; ++h) {
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
+ float f = -INFINITY;
+
+ for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+ // TODO: reimplement this like in llama_kv_cache_unified
+ if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
+ if (hparams.use_alibi) {
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+ } else {
+ f = 0.0f;
  }
+ break;
  }
  }
+
+ data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
  }
  }
  }
@@ -362,53 +282,74 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {

  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
  if (self_kq_mask) {
- kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
  }
  }

  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
  if (self_kq_mask) {
- kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
  }

  if (self_kq_mask_swa) {
- kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
  }
  }

  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
- if (cross_kq_mask) {
- const int64_t n_enc = cross_kq_mask->ne[0];
- const int64_t n_tokens = ubatch->n_tokens;
+ GGML_ASSERT(cross_kq_mask);

- GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+ const int64_t n_enc = cross_kq_mask->ne[0];
+ const int64_t n_tokens = ubatch->n_tokens;

- float * data = (float *) cross_kq_mask->data;
+ GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing

- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- for (int i = 0; i < n_enc; ++i) {
- float f = -INFINITY;
- for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[j][s];
- if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
- f = 0.0f;
- }
+ float * data = (float *) cross_kq_mask->data;
+
+ for (int h = 0; h < 1; ++h) {
+ for (int i = 0; i < n_tokens; ++i) {
+ for (int j = 0; j < n_enc; ++j) {
+ float f = -INFINITY;
+
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
+
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+ f = 0.0f;
  }
- data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
  }
+
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
  }
+ }

- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
- for (int j = 0; j < n_enc; ++j) {
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
- }
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+ for (int j = 0; j < n_enc; ++j) {
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
  }
  }
  }
  }

+ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+ if (self_kq_mask) {
+ mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ }
+
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+ if (s_copy) {
+ GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+ int32_t * data = (int32_t *) s_copy->data;
+
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+ for (uint32_t i = 0; i < n_rs; ++i) {
+ data[i] = mctx->get_recr()->s_copy(i);
+ }
+ }
+ }
+
  //
  // llm_graph_context
  //
@@ -448,16 +389,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
  backend_cpu (params.backend_cpu),
  cvec (params.cvec),
  loras (params.loras),
- memory (params.memory),
+ mctx (params.mctx),
  cross (params.cross),
  cb_func (params.cb),
  res (std::make_unique<llm_graph_result>()) {
  }

- int64_t llm_graph_context::n_pos_per_embd() const {
- return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
- }
-
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
  if (cb_func) {
  cb_func(ubatch, cur, name, il);
@@ -647,6 +584,7 @@ ggml_tensor * llm_graph_context::build_ffn(
  {
  // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
  int64_t split_point = cur->ne[0] / 2;
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
  ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
  ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));

@@ -656,6 +594,20 @@ ggml_tensor * llm_graph_context::build_ffn(
  cur = ggml_mul(ctx0, x0, x1);
  cb(cur, "ffn_mul", il);
  } break;
+ case LLM_FFN_GEGLU:
+ {
+ // Split into two equal parts
+ int64_t split_point = cur->ne[0] / 2;
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
+ ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+ ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+ x0 = ggml_gelu(ctx0, x0);
+ cb(x0, "ffn_gelu", il);
+
+ cur = ggml_mul(ctx0, x0, x1);
+ cb(cur, "ffn_geglu", il);
+ } break;
  }

  if (gate && type_gate == LLM_FFN_PAR) {
@@ -766,9 +718,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

  if (weight_before_ffn) {
- // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
- ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
- repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+ // repeat cur to [n_embd, n_expert_used, n_tokens]
+ ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
  cur = ggml_mul(ctx0, repeated, weights);
  cb(cur, "ffn_moe_weighted", il);
  }
@@ -888,11 +839,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
  }

  ggml_tensor * llm_graph_context::build_inp_pos() const {
- auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
+ auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());

  auto & cur = inp->pos;

- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
  ggml_set_input(cur);

  res->add_input(std::move(inp));
@@ -915,6 +866,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
  }

  ggml_tensor * llm_graph_context::build_inp_out_ids() const {
+ // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
+ // but this would make the graph topology depend on the number of output tokens, which can interere with
+ // features that require constant topology such as pipline parallelism
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
+ //if (n_outputs < n_tokens) {
+ // return nullptr;
+ //}
+
  auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);

  auto & cur = inp->out_ids;
@@ -932,7 +891,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {

  auto & cur = inp->mean;

- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
  ggml_set_input(cur);

  res->add_input(std::move(inp));
@@ -945,41 +904,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {

  auto & cur = inp->cls;

- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
- ggml_set_input(cur);
-
- res->add_input(std::move(inp));
-
- return cur;
- }
-
- ggml_tensor * llm_graph_context::build_inp_s_copy() const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
-
- auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
-
- const auto n_kv = kv_self->n;
-
- auto & cur = inp->s_copy;
-
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
- ggml_set_input(cur);
-
- res->add_input(std::move(inp));
-
- return cur;
- }
-
- ggml_tensor * llm_graph_context::build_inp_s_mask() const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
-
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
-
- const auto n_kv = kv_self->n;
-
- auto & cur = inp->s_mask;
-
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
  ggml_set_input(cur);

  res->add_input(std::move(inp));
@@ -1025,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
  }

  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);

- const auto n_kv = kv_self->get_n();
+ const auto n_kv = mctx_cur->get_n_kv();

  auto & cur = inp->pos_bucket;

@@ -1056,6 +981,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
  return pos_bias;
  }

+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
+
+ {
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+ const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
+ ggml_set_input(inp->self_kq_mask);
+
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+ }
+
+ {
+ const auto n_rs = mctx_cur->get_recr()->get_n_rs();
+
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+ ggml_set_input(inp->s_copy);
+ }
+
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+ }
+
  ggml_tensor * llm_graph_context::build_attn_mha(
  ggml_cgraph * gf,
  ggml_tensor * q,
@@ -1231,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn(
  }

  llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);

  {
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

- const auto n_kv = kv_self->get_n();
+ const auto n_kv = mctx_cur->get_n_kv();

  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,25 +1220,29 @@ ggml_tensor * llm_graph_context::build_attn(
  ggml_build_forward_expand(gf, k_cur);
  ggml_build_forward_expand(gf, v_cur);

- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

  // store to KV cache
  {
- ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
- ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
  }

  const auto & kq_mask = inp->get_kq_mask();

  ggml_tensor * q = q_cur;
- ggml_tensor * k = kv_self->get_k(ctx0, il);
- ggml_tensor * v = kv_self->get_v(ctx0, il);
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);

  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  cb(cur, "kqv_out", il);

  if (wo) {
  cur = build_lora_mm(wo, cur);
+ if (arch == LLM_ARCH_GLM4) {
+ // GLM4 seems to have numerical issues with half-precision accumulators
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+ }
  }

  if (wo_b) {
@@ -1296,36 +1252,6 @@ ggml_tensor * llm_graph_context::build_attn(
  return cur;
  }

- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
-
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
-
- {
- const auto n_kv = kv_self->get_kv_base()->get_n();
-
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask, "KQ_mask", -1);
- ggml_set_input(inp->self_kq_mask);
-
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
- }
-
- {
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
- const auto n_kv = kv_self->get_kv_swa()->get_n();
-
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
- ggml_set_input(inp->self_kq_mask_swa);
-
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
- }
-
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
- }
-
  ggml_tensor * llm_graph_context::build_attn(
  llm_graph_input_attn_kv_unified_iswa * inp,
  ggml_cgraph * gf,
@@ -1344,33 +1270,29 @@ ggml_tensor * llm_graph_context::build_attn(
  ggml_build_forward_expand(gf, k_cur);
  ggml_build_forward_expand(gf, v_cur);

- const bool is_swa = hparams.is_swa(il);
+ const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);

- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+ const bool is_swa = hparams.is_swa(il);

- const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+ const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();

  // store to KV cache
  {
- ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
- ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
  }

  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

  ggml_tensor * q = q_cur;
- ggml_tensor * k = kv->get_k(ctx0, il);
- ggml_tensor * v = kv->get_v(ctx0, il);
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);

  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
  cb(cur, "kqv_out", il);

  if (wo) {
  cur = build_lora_mm(wo, cur);
- if (arch == LLM_ARCH_GLM4) {
- // GLM4 seems to have numerical issues with half-precision accumulators
- ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
- }
  }

  if (wo_b) {
@@ -1439,56 +1361,182 @@ ggml_tensor * llm_graph_context::build_attn(
  return cur;
  }

- ggml_tensor * llm_graph_context::build_copy_mask_state(
- ggml_cgraph * gf,
- ggml_tensor * s,
- ggml_tensor * state_copy,
- ggml_tensor * state_mask,
- int32_t n_state,
- int32_t n_seqs) const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+ ggml_tensor * llm_graph_context::build_attn(
+ llm_graph_input_mem_hybrid * inp,
+ ggml_cgraph * gf,
+ ggml_tensor * wo,
+ ggml_tensor * wo_b,
+ ggml_tensor * q_cur,
+ ggml_tensor * k_cur,
+ ggml_tensor * v_cur,
+ ggml_tensor * kq_b,
+ ggml_tensor * v_mla,
+ float kq_scale,
+ int il) const {
+ // these nodes are added to the graph together so that they are not reordered
+ // by doing so, the number of splits in the graph is reduced
+ ggml_build_forward_expand(gf, q_cur);
+ ggml_build_forward_expand(gf, k_cur);
+ ggml_build_forward_expand(gf, v_cur);
+
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();

- const auto n_kv = kv_self->n;
- const auto kv_head = kv_self->head;
+ // store to KV cache
+ {
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+ }

- ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
+ const auto & kq_mask = inp->get_kq_mask();

- // copy states
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
- // this shrinks the tensors's ne[1] to n_kv
- states = ggml_get_rows(ctx0, states, state_copy);
+ ggml_tensor * q = q_cur;
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);

- // clear states of sequences which are starting at the beginning of this batch
- // FIXME: zero-out NANs?
- states = ggml_mul(ctx0, states, state_mask);
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+ cb(cur, "kqv_out", il);

- // copy states which won't be changed further (between n_seqs and n_kv)
+ if (wo) {
+ cur = build_lora_mm(wo, cur);
+ if (arch == LLM_ARCH_GLM4) {
+ // GLM4 seems to have numerical issues with half-precision accumulators
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+ }
+ }
+
+ if (wo_b) {
+ cur = ggml_add(ctx0, cur, wo_b);
+ }
+
+ return cur;
+ }
+
+ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+
+ {
+ const auto n_kv = mctx_cur->get_base()->get_n_kv();
+
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
+ ggml_set_input(inp->self_kq_mask);
+
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+ }
+
+ {
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+ const auto n_kv = mctx_cur->get_swa()->get_n_kv();
+
+ inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+ //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+ ggml_set_input(inp->self_kq_mask_swa);
+
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+ }
+
+ return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+ }
+
+ ggml_tensor * llm_graph_context::build_rs(
+ ggml_cgraph * gf,
+ ggml_tensor * s,
+ ggml_tensor * state_copy,
+ int32_t state_size,
+ int32_t n_seqs,
+ uint32_t n_kv,
+ uint32_t kv_head,
+ uint32_t kv_size,
+ int32_t rs_zero,
+ bool avoid_copies) const {
+
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+
+ // Clear a single state which will then be copied to the other cleared states.
+ // Note that this is a no-op when the view is zero-sized.
+ ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+ ggml_tensor * output_states;
+
+ if (!avoid_copies) {
+ // copy states
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+ // {state_size, kv_size} -> {state_size, n_seqs}
+ output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+ ggml_build_forward_expand(gf, output_states);
+ } else {
+ // FIXME: make the gathering operation happen before the copy below
+ // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+ output_states = states;
+ }
+
+ // copy extra states which won't be changed further (between n_seqs and n_kv)
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
  ggml_build_forward_expand(gf,
  ggml_cpy(ctx0,
- ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
- ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+ states_extra,
+ ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+ return output_states;
+ }
+
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+ auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
+
+ const auto n_rs = mctx_cur->get_n_rs();
+
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+ ggml_set_input(inp->s_copy);
+
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
+ }
+
+ ggml_tensor * llm_graph_context::build_rs(
+ llm_graph_input_rs * inp,
+ ggml_cgraph * gf,
+ ggml_tensor * s,
+ int32_t state_size,
+ int32_t n_seqs,
+ bool avoid_copies) const {
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+ }
+
+ ggml_tensor * llm_graph_context::build_rs(
+ llm_graph_input_mem_hybrid * inp,
+ ggml_cgraph * gf,
+ ggml_tensor * s,
+ int32_t state_size,
+ int32_t n_seqs,
+ bool avoid_copies) const {
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();

- // the part of the states that will be used and modified
- return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
  }

  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
- ggml_cgraph * gf,
- ggml_tensor * state_copy,
- ggml_tensor * state_mask,
- const llama_ubatch & ubatch,
+ llm_graph_input_rs * inp,
+ ggml_cgraph * gf,
+ const llama_ubatch & ubatch,
  int il) const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

  const auto token_shift_count = hparams.token_shift_count;

  const int64_t n_seqs = ubatch.n_seqs;

- ggml_tensor * token_shift_all = kv_self->k_l[il];
+ ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);

- ggml_tensor * token_shift = build_copy_mask_state(
- gf, token_shift_all, state_copy, state_mask,
- hparams.n_embd_k_s(), n_seqs);
+ ggml_tensor * token_shift = build_rs(
+ inp, gf, token_shift_all,
+ hparams.n_embd_r(), n_seqs);

  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

@@ -1499,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  ggml_tensor * token_shift,
  const llama_ubatch & ubatch,
  int il) const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

  const auto token_shift_count = hparams.token_shift_count;
  const auto n_embd = hparams.n_embd;

  const int64_t n_seqs = ubatch.n_seqs;

- const auto kv_head = kv_self->head;
+ const auto kv_head = mctx_cur->get_head();

  return ggml_cpy(
  ctx0,
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
- ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
+ ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
  );
  }

@@ -1562,20 +1610,32 @@ void llm_graph_context::build_pooling(
  ggml_tensor * inp_cls = build_inp_cls();
  inp = ggml_get_rows(ctx0, inp, inp_cls);

- // classification head
- // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
- GGML_ASSERT(cls != nullptr);
- GGML_ASSERT(cls_b != nullptr);
-
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
- cur = ggml_tanh(ctx0, cur);
-
- // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
- // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
- if (cls_out) {
- GGML_ASSERT(cls_out_b != nullptr);
-
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+ if (cls) {
+ // classification head
+ // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+ cur = ggml_mul_mat(ctx0, cls, inp);
+ if (cls_b) {
+ cur = ggml_add(ctx0, cur, cls_b);
+ }
+ cur = ggml_tanh(ctx0, cur);
+
+ // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+ // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+ if (cls_out) {
+ cur = ggml_mul_mat(ctx0, cls_out, cur);
+ if (cls_out_b) {
+ cur = ggml_add(ctx0, cur, cls_out_b);
+ }
+ }
+ } else if (cls_out) {
+ // Single layer classification head (direct projection)
+ // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+ cur = ggml_mul_mat(ctx0, cls_out, inp);
+ if (cls_out_b) {
+ cur = ggml_add(ctx0, cur, cls_out_b);
+ }
+ } else {
+ GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
  }
  } break;
  default: