@fugood/llama.node 1.0.0-beta.5 → 1.0.0-beta.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/binding.ts +3 -1
- package/lib/index.js +2 -0
- package/lib/index.ts +3 -1
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +27 -26
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +28 -7
- package/src/LlamaCompletionWorker.h +4 -0
- package/src/LlamaContext.cpp +14 -17
- package/src/common.hpp +7 -6
- package/src/llama.cpp/CMakeLists.txt +15 -4
- package/src/llama.cpp/common/CMakeLists.txt +15 -24
- package/src/llama.cpp/common/arg.cpp +172 -110
- package/src/llama.cpp/common/chat-parser.cpp +385 -0
- package/src/llama.cpp/common/chat-parser.h +120 -0
- package/src/llama.cpp/common/chat.cpp +726 -596
- package/src/llama.cpp/common/chat.h +74 -8
- package/src/llama.cpp/common/common.cpp +56 -38
- package/src/llama.cpp/common/common.h +9 -3
- package/src/llama.cpp/common/json-partial.cpp +256 -0
- package/src/llama.cpp/common/json-partial.h +38 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/src/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/src/llama.cpp/common/sampling.cpp +7 -8
- package/src/llama.cpp/common/speculative.cpp +6 -4
- package/src/llama.cpp/ggml/CMakeLists.txt +48 -3
- package/src/llama.cpp/ggml/include/ggml.h +22 -3
- package/src/llama.cpp/ggml/src/CMakeLists.txt +81 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +131 -49
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2162 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +12 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +64 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +282 -100
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1570 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +119 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +204 -49
- package/src/llama.cpp/include/llama.h +145 -40
- package/src/llama.cpp/src/CMakeLists.txt +5 -1
- package/src/llama.cpp/src/llama-arch.cpp +99 -3
- package/src/llama.cpp/src/llama-arch.h +10 -1
- package/src/llama.cpp/src/llama-batch.cpp +728 -272
- package/src/llama.cpp/src/llama-batch.h +112 -54
- package/src/llama.cpp/src/llama-chat.cpp +19 -2
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +525 -339
- package/src/llama.cpp/src/llama-context.h +38 -17
- package/src/llama.cpp/src/llama-cparams.cpp +4 -0
- package/src/llama.cpp/src/llama-cparams.h +2 -0
- package/src/llama.cpp/src/llama-grammar.cpp +12 -2
- package/src/llama.cpp/src/llama-graph.cpp +413 -353
- package/src/llama.cpp/src/llama-graph.h +112 -56
- package/src/llama.cpp/src/llama-hparams.cpp +10 -2
- package/src/llama.cpp/src/llama-hparams.h +13 -2
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +279 -0
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +128 -0
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +1815 -0
- package/src/llama.cpp/src/llama-kv-cache-unified.h +303 -0
- package/src/llama.cpp/src/llama-kv-cells.h +415 -0
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
- package/src/llama.cpp/src/llama-memory-hybrid.h +138 -0
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +1112 -0
- package/src/llama.cpp/src/llama-memory-recurrent.h +183 -0
- package/src/llama.cpp/src/llama-memory.cpp +41 -0
- package/src/llama.cpp/src/llama-memory.h +86 -5
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/src/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/src/llama.cpp/src/llama-model.cpp +1137 -528
- package/src/llama.cpp/src/llama-model.h +4 -0
- package/src/llama.cpp/src/llama-quant.cpp +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2 -2
- package/src/llama.cpp/src/llama-vocab.cpp +69 -32
- package/src/llama.cpp/src/llama-vocab.h +1 -0
- package/src/llama.cpp/src/llama.cpp +11 -7
- package/src/llama.cpp/src/unicode.cpp +5 -0
- package/src/tts_utils.h +1 -1
- package/src/llama.cpp/common/json.hpp +0 -24766
- package/src/llama.cpp/common/minja/chat-template.hpp +0 -541
- package/src/llama.cpp/common/minja/minja.hpp +0 -2974
- package/src/llama.cpp/common/stb_image.h +0 -7988
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13326
- package/src/llama.cpp/src/llama-kv-cache.cpp +0 -2827
- package/src/llama.cpp/src/llama-kv-cache.h +0 -515
- /package/src/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
package/src/llama.cpp/src/llama-graph.cpp

@@ -3,7 +3,11 @@
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
-
+
+#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache-unified-iswa.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"

 #include <cassert>
 #include <cmath>
@@ -83,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {

 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-
+        mctx->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }

 void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
-
-    //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
+    GGML_ASSERT(out_ids);

-
-        LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
-    } else {
-        const int64_t n_tokens = ubatch->n_tokens;
+    const int64_t n_tokens = ubatch->n_tokens;

-
-
+    GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
+    int32_t * data = (int32_t *) out_ids->data;

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            data[0] = n_tokens - 1;
-        } else {
-            GGML_ASSERT(n_outputs == 0);
-        }
+    if (n_outputs == n_tokens) {
+        for (int i = 0; i < n_tokens; ++i) {
+            data[i] = i;
+        }
+
+        return;
+    }
+
+    GGML_ASSERT(ubatch->output);
+
+    int n_outputs = 0;
+
+    for (int i = 0; i < n_tokens; ++i) {
+        if (ubatch->output[i]) {
+            data[n_outputs++] = i;
         }
     }
 }
@@ -126,139 +122,114 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t
+        const int64_t n_seqs_unq = ubatch->n_seqs_unq;

         GGML_ASSERT(mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));

         float * data = (float *) mean->data;
-        memset(mean->data, 0, n_tokens
-
-        std::vector<uint64_t> sum(n_tokens, 0);
+        memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));

-
-
+        std::vector<uint64_t> sums(n_seqs_unq, 0);
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[i][s];
+                const int32_t seq_idx = ubatch->seq_idx[seq_id];

-
-
-
-                sum[seq_id] += ubatch->n_seq_tokens;
+                sums[seq_idx] += ubatch->n_seq_tokens;
+            }
         }

-        std::vector<float> div(
-        for (int
-        const uint64_t
-        if (
-        div[
+        std::vector<float> div(n_seqs_unq, 0.0f);
+        for (int s = 0; s < n_seqs_unq; ++s) {
+            const uint64_t sum = sums[s];
+            if (sum > 0) {
+                div[s] = 1.0f/float(sum);
             }
         }

-        for (int
-
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[i][s];
+                const int32_t seq_idx = ubatch->seq_idx[seq_id];

-
-
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    data[seq_idx*n_tokens + i + j] = div[seq_idx];
+                }
             }
         }
     }
 }

 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
-
-
-
-    const int64_t n_tokens = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs = ubatch->n_seqs;
+    const int64_t n_tokens = ubatch->n_tokens;
+    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+    const int64_t n_seqs_unq = ubatch->n_seqs_unq;

+    if (cparams.embeddings && (
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
+                )) {
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0,
+        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

-        for (int
-
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[i][s];
+                const int32_t seq_idx = ubatch->seq_idx[seq_id];

-
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
-
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
-
-                if (pos == 0) {
-                    data[seq_id] = s*n_seq_tokens + i;
-                }
+                data[seq_idx] = i;
             }
         }
     }

     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs = ubatch->n_seqs;
-
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0,
+        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

-        std::vector<int> last_pos(
-        std::vector<int> last_row(
+        std::vector<int> last_pos(n_seqs_unq, -1);
+        std::vector<int> last_row(n_seqs_unq, -1);

-        for (int
-        const
-
-        // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-        GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_pos pos = ubatch->pos[i];

-            for (int
-            const
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[i][s];
+                const int32_t seq_idx = ubatch->seq_idx[seq_id];

-            if (pos >= last_pos[
-                last_pos[
-                last_row[
+                if (pos >= last_pos[seq_idx]) {
+                    last_pos[seq_idx] = pos;
+                    last_row[seq_idx] = i;
                 }
             }
         }

-        for (int
-            if (last_row[
-                data[
+        for (int s = 0; s < n_seqs_unq; ++s) {
+            if (last_row[s] >= 0) {
+                data[s] = last_row[s];
             }
         }
     }
 }

-void
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);

-    const int64_t
+    const int64_t n_rs = mctx->get_n_rs();

     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;

         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i <
-            data[i] =
-        }
-    }
-}
-
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    const int64_t n_kv = kv_self->n;
-
-    if (s_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
-        float * data = (float *) s_mask->data;
-
-        // clear unused states
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_self->s_mask(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->s_copy(i);
         }
     }
 }
@@ -274,87 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-
-
-
-
-
-
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
-                    }
-                }
-            }
-        } else {
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-            const int64_t n_stride = ubatch->n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
+    const int64_t n_kv = ubatch->n_tokens;
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+    float * data = (float *) kq_mask->data;

-
-
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
                         }
+                        break;
                     }
                 }
+
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
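Note: the rewritten llm_graph_input_attn_no_cache::set_input above collapses the old per-sequence loops into a single pass over token pairs. A minimal standalone sketch of the same masking rule (a hypothetical helper on plain buffers, assuming one sequence id per token, not the ggml code itself):

    #include <cmath>
    #include <vector>

    // mask[i1*n + i0] = 0 when token i0 is visible to token i1, -INF otherwise.
    // Visibility requires the same sequence id and, with causal attention,
    // pos[i0] <= pos[i1]; the ALiBi path stores -|pos[i0] - pos[i1]| instead of 0.
    static std::vector<float> build_kq_mask(const std::vector<int> & seq_id,
                                            const std::vector<int> & pos,
                                            bool causal_attn) {
        const size_t n = seq_id.size();
        std::vector<float> mask(n * n, -INFINITY);
        for (size_t i1 = 0; i1 < n; ++i1) {
            for (size_t i0 = 0; i0 < n; ++i0) {
                if (seq_id[i0] == seq_id[i1] && (!causal_attn || pos[i0] <= pos[i1])) {
                    mask[i1 * n + i0] = 0.0f;
                }
            }
        }
        return mask;
    }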
@@ -362,53 +282,74 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {

 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-
+        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }

 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-
+        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }

     if (self_kq_mask_swa) {
-
+        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }

 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
-
-    const int64_t n_enc = cross_kq_mask->ne[0];
-    const int64_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(cross_kq_mask);

-
-
+    const int64_t n_enc = cross_kq_mask->ne[0];
+    const int64_t n_tokens = ubatch->n_tokens;

-
+    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing

-
-
-
-
-
-
-
-
-
+    float * data = (float *) cross_kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
+
+                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+                        f = 0.0f;
                     }
-                data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
+
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
             }
+        }

-
-
-
-            }
+        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
            }
         }
     }
 }

+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -448,16 +389,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     backend_cpu (params.backend_cpu),
     cvec (params.cvec),
     loras (params.loras),
-
+    mctx (params.mctx),
     cross (params.cross),
     cb_func (params.cb),
     res (std::make_unique<llm_graph_result>()) {
 }

-int64_t llm_graph_context::n_pos_per_embd() const {
-    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
-}
-
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     if (cb_func) {
         cb_func(ubatch, cur, name, il);
@@ -647,6 +584,7 @@ ggml_tensor * llm_graph_context::build_ffn(
             {
                 // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
                 int64_t split_point = cur->ne[0] / 2;
+                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
                 ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
                 ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));

@@ -656,6 +594,20 @@ ggml_tensor * llm_graph_context::build_ffn(
                 cur = ggml_mul(ctx0, x0, x1);
                 cb(cur, "ffn_mul", il);
             } break;
+        case LLM_FFN_GEGLU:
+            {
+                // Split into two equal parts
+                int64_t split_point = cur->ne[0] / 2;
+                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_gelu(ctx0, x0);
+                cb(x0, "ffn_gelu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
+                cb(cur, "ffn_geglu", il);
+            } break;
     }

     if (gate && type_gate == LLM_FFN_PAR) {
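Note: the new LLM_FFN_GEGLU branch above splits the up-projection in half and gates one half with GELU before the element-wise multiply. A minimal sketch of the same computation on a plain float vector (a hypothetical helper using the tanh approximation of GELU, not the ggml implementation itself):

    #include <cmath>
    #include <vector>

    // geglu(x): out[i] = gelu(x[i]) * x[i + d/2] for the two halves of x.
    static std::vector<float> geglu(const std::vector<float> & x) {
        const size_t half = x.size() / 2;
        std::vector<float> out(half);
        for (size_t i = 0; i < half; ++i) {
            const float v = x[i];
            // tanh approximation of GELU
            const float g = 0.5f * v * (1.0f + std::tanh(0.79788456f * (v + 0.044715f * v * v * v)));
            out[i] = g * x[half + i];
        }
        return out;
    }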
@@ -766,9 +718,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

     if (weight_before_ffn) {
-        //
-        ggml_tensor * repeated =
-        repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+        // repeat cur to [n_embd, n_expert_used, n_tokens]
+        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
         cur = ggml_mul(ctx0, repeated, weights);
         cb(cur, "ffn_moe_weighted", il);
     }
@@ -888,11 +839,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 }

 ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
+    auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());

     auto & cur = inp->pos;

-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
     ggml_set_input(cur);

     res->add_input(std::move(inp));
@@ -915,6 +866,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
 }

 ggml_tensor * llm_graph_context::build_inp_out_ids() const {
+    // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
+    // but this would make the graph topology depend on the number of output tokens, which can interere with
+    // features that require constant topology such as pipline parallelism
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
+    //if (n_outputs < n_tokens) {
+    //    return nullptr;
+    //}
+
     auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);

     auto & cur = inp->out_ids;
@@ -932,7 +891,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {

     auto & cur = inp->mean;

-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens,
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
     ggml_set_input(cur);

     res->add_input(std::move(inp));
@@ -945,41 +904,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {

     auto & cur = inp->cls;

-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32,
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
-
-    const auto n_kv = kv_self->n;
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
-
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
-
-    const auto n_kv = kv_self->n;
-
-    auto & cur = inp->s_mask;
-
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
     ggml_set_input(cur);

     res->add_input(std::move(inp));
@@ -1025,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }

 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

-    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams,
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);

-    const auto n_kv =
+    const auto n_kv = mctx_cur->get_n_kv();

     auto & cur = inp->pos_bucket;

@@ -1056,6 +981,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }

+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
+
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+        ggml_set_input(inp->s_copy);
+    }
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
@@ -1231,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }

 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams,
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);

     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

-        const auto n_kv =
+        const auto n_kv = mctx_cur->get_n_kv();

         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,25 +1220,29 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);

-    const
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

     // store to KV cache
     {
-        ggml_build_forward_expand(gf,
-        ggml_build_forward_expand(gf,
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }

     const auto & kq_mask = inp->get_kq_mask();

     ggml_tensor * q = q_cur;
-    ggml_tensor * k =
-    ggml_tensor * v =
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);

     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {
         cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
     }

     if (wo_b) {
@@ -1296,36 +1252,6 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
-
-    {
-        const auto n_kv = kv_self->get_kv_base()->get_n();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
-        const auto n_kv = kv_self->get_kv_swa()->get_n();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
@@ -1344,33 +1270,29 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);

-    const
+    const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);

-    const
+    const bool is_swa = hparams.is_swa(il);

-    const auto *
+    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();

     // store to KV cache
     {
-        ggml_build_forward_expand(gf,
-        ggml_build_forward_expand(gf,
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }

     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

     ggml_tensor * q = q_cur;
-    ggml_tensor * k =
-    ggml_tensor * v =
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);

     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {
         cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
     }

     if (wo_b) {
@@ -1439,56 +1361,182 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-ggml_tensor * llm_graph_context::
-
-
-
-
-
-
-
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();

-
-
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+    }

-
+    const auto & kq_mask = inp->get_kq_mask();

-
-
-
-    states = ggml_get_rows(ctx0, states, state_copy);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);

-
-
-    states = ggml_mul(ctx0, states, state_mask);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);

-
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+
+    {
+        const auto n_kv = mctx_cur->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        int32_t state_size,
+        int32_t n_seqs,
+        uint32_t n_kv,
+        uint32_t kv_head,
+        uint32_t kv_size,
+        int32_t rs_zero,
+        bool avoid_copies) const {
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;
+
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
+    }
+
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-
-            ggml_view_1d(ctx0, s,
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+    return output_states;
+}
+
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
+
+    const auto n_rs = mctx_cur->get_n_rs();
+
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+    ggml_set_input(inp->s_copy);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();

-
-    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }

 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-
-
-
-    const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
         int il) const {
-    const
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

     const auto token_shift_count = hparams.token_shift_count;

     const int64_t n_seqs = ubatch.n_seqs;

-    ggml_tensor * token_shift_all =
+    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);

-    ggml_tensor * token_shift =
-        gf, token_shift_all,
-        hparams.
+    ggml_tensor * token_shift = build_rs(
+        inp, gf, token_shift_all,
+        hparams.n_embd_r(), n_seqs);

     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

@@ -1499,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;

     const int64_t n_seqs = ubatch.n_seqs;

-    const auto kv_head =
+    const auto kv_head = mctx_cur->get_head();

     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0,
+        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
     );
 }

@@ -1562,20 +1610,32 @@ void llm_graph_context::build_pooling(
                 ggml_tensor * inp_cls = build_inp_cls();
                 inp = ggml_get_rows(ctx0, inp, inp_cls);

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                if (cls) {
+                    // classification head
+                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                    cur = ggml_mul_mat(ctx0, cls, inp);
+                    if (cls_b) {
+                        cur = ggml_add(ctx0, cur, cls_b);
+                    }
+                    cur = ggml_tanh(ctx0, cur);
+
+                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                    if (cls_out) {
+                        cur = ggml_mul_mat(ctx0, cls_out, cur);
+                        if (cls_out_b) {
+                            cur = ggml_add(ctx0, cur, cls_out_b);
+                        }
+                    }
+                } else if (cls_out) {
+                    // Single layer classification head (direct projection)
+                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+                    cur = ggml_mul_mat(ctx0, cls_out, inp);
+                    if (cls_out_b) {
+                        cur = ggml_add(ctx0, cur, cls_out_b);
+                    }
+                } else {
+                    GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
                 }
             } break;
         default: