@fugood/llama.node 1.0.0-beta.4 → 1.0.0-beta.6
This diff shows the changes between publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- package/CMakeLists.txt +7 -4
- package/lib/binding.ts +1 -1
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +27 -26
- package/src/LlamaCompletionWorker.cpp +21 -4
- package/src/LlamaCompletionWorker.h +2 -0
- package/src/LlamaContext.cpp +3 -12
- package/src/common.hpp +6 -5
- package/src/llama.cpp/CMakeLists.txt +15 -4
- package/src/llama.cpp/common/CMakeLists.txt +15 -24
- package/src/llama.cpp/common/arg.cpp +172 -110
- package/src/llama.cpp/common/chat-parser.cpp +385 -0
- package/src/llama.cpp/common/chat-parser.h +120 -0
- package/src/llama.cpp/common/chat.cpp +726 -596
- package/src/llama.cpp/common/chat.h +74 -8
- package/src/llama.cpp/common/common.cpp +56 -38
- package/src/llama.cpp/common/common.h +9 -3
- package/src/llama.cpp/common/json-partial.cpp +256 -0
- package/src/llama.cpp/common/json-partial.h +38 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/src/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/src/llama.cpp/common/sampling.cpp +7 -8
- package/src/llama.cpp/common/speculative.cpp +6 -4
- package/src/llama.cpp/ggml/CMakeLists.txt +48 -3
- package/src/llama.cpp/ggml/include/ggml.h +22 -3
- package/src/llama.cpp/ggml/src/CMakeLists.txt +81 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +131 -49
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2162 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +12 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +64 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +282 -100
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1570 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +119 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +204 -49
- package/src/llama.cpp/include/llama.h +145 -40
- package/src/llama.cpp/src/CMakeLists.txt +5 -1
- package/src/llama.cpp/src/llama-arch.cpp +99 -3
- package/src/llama.cpp/src/llama-arch.h +10 -1
- package/src/llama.cpp/src/llama-batch.cpp +728 -272
- package/src/llama.cpp/src/llama-batch.h +112 -54
- package/src/llama.cpp/src/llama-chat.cpp +19 -2
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +525 -339
- package/src/llama.cpp/src/llama-context.h +38 -17
- package/src/llama.cpp/src/llama-cparams.cpp +4 -0
- package/src/llama.cpp/src/llama-cparams.h +2 -0
- package/src/llama.cpp/src/llama-grammar.cpp +12 -2
- package/src/llama.cpp/src/llama-graph.cpp +413 -353
- package/src/llama.cpp/src/llama-graph.h +112 -56
- package/src/llama.cpp/src/llama-hparams.cpp +10 -2
- package/src/llama.cpp/src/llama-hparams.h +13 -2
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +279 -0
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +128 -0
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +1815 -0
- package/src/llama.cpp/src/llama-kv-cache-unified.h +303 -0
- package/src/llama.cpp/src/llama-kv-cells.h +415 -0
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
- package/src/llama.cpp/src/llama-memory-hybrid.h +138 -0
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +1112 -0
- package/src/llama.cpp/src/llama-memory-recurrent.h +183 -0
- package/src/llama.cpp/src/llama-memory.cpp +41 -0
- package/src/llama.cpp/src/llama-memory.h +86 -5
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/src/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/src/llama.cpp/src/llama-model.cpp +1137 -528
- package/src/llama.cpp/src/llama-model.h +4 -0
- package/src/llama.cpp/src/llama-quant.cpp +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2 -2
- package/src/llama.cpp/src/llama-vocab.cpp +69 -32
- package/src/llama.cpp/src/llama-vocab.h +1 -0
- package/src/llama.cpp/src/llama.cpp +11 -7
- package/src/llama.cpp/src/unicode.cpp +5 -0
- package/src/tts_utils.h +1 -1
- package/src/llama.cpp/common/json.hpp +0 -24766
- package/src/llama.cpp/common/minja/chat-template.hpp +0 -541
- package/src/llama.cpp/common/minja/minja.hpp +0 -2974
- package/src/llama.cpp/common/stb_image.h +0 -7988
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13326
- package/src/llama.cpp/src/llama-kv-cache.cpp +0 -2827
- package/src/llama.cpp/src/llama-kv-cache.h +0 -515
- /package/src/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
package/src/llama.cpp/src/llama-graph.h (+112 -56)

@@ -17,10 +17,12 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-
-
-class
-class
+struct llama_memory_context_i;
+
+class llama_kv_cache_unified_context;
+class llama_kv_cache_unified_iswa_context;
+class llama_memory_recurrent_context;
+class llama_memory_hybrid_context;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -35,6 +37,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
+    LLM_FFN_GEGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -92,14 +95,14 @@ public:
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(
+    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const
+    const uint32_t n_pos_per_embd = 1;
 };
 
 // temperature tuning, used by llama4
@@ -133,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const
+            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -141,7 +144,8 @@ public:
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-
+
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -186,28 +190,16 @@ public:
     const llama_cparams & cparams;
 };
 
-class
+class llm_graph_input_rs : public llm_graph_input_i {
 public:
-
-    virtual ~
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+    virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const
-};
-
-class llm_graph_input_s_mask : public llm_graph_input_i {
-public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
-    virtual ~llm_graph_input_s_mask() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_mask; // F32 [1, n_kv]
-
-    const llama_kv_cache_recurrent * kv_self;
+    const llama_memory_recurrent_context * mctx;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -247,10 +239,10 @@ public:
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const
+            const llama_kv_cache_unified_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -264,7 +256,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -272,10 +264,10 @@ public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
            const llama_cparams & cparams,
-            const
+            const llama_kv_cache_unified_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -292,7 +284,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const
+    const llama_kv_cache_unified_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -310,6 +302,33 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_memory_hybrid_context * mctx) :
+        hparams(hparams),
+        cparams(cparams),
+        mctx(mctx) {
+    }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy; // I32 [kv_size]
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_memory_hybrid_context * mctx;
+};
+
 //
 // llm_graph_result
 //
@@ -383,12 +402,12 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec
-    const llama_adapter_loras
-    const
-    const llama_cross
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
 
-
+    uint32_t n_outputs;
 
     const llm_graph_cb & cb;
 };
@@ -422,8 +441,8 @@ struct llm_graph_context {
     const float norm_eps;
     const float norm_rms_eps;
 
-    const
-    const
+    const int64_t n_tokens;
+    const int64_t n_outputs;
     const int32_t n_ctx_orig; // yarn
 
     const enum llama_pooling_type pooling_type;
@@ -435,10 +454,10 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec
-    const llama_adapter_loras
-    const
-    const llama_cross
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
 
     const llm_graph_cb & cb_func;
 
@@ -446,8 +465,6 @@ struct llm_graph_context {
 
     llm_graph_context(const llm_graph_params & params);
 
-    int64_t n_pos_per_embd() const;
-
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
     //
@@ -518,14 +535,14 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
-    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // attention
     //
@@ -600,23 +617,62 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
+    ggml_tensor * build_attn(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
     //
     // recurrent
    //
 
-
-
-
-
-
-
-
+    // TODO: avoid notion of "kv"
+    // TODO: move this implementation to llama_memory_recurrent.
+    // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    // `llama_memory_recurrent`
+    ggml_tensor * build_rs(
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            int32_t state_size,
+            int32_t n_seqs,
+            uint32_t n_kv,
+            uint32_t kv_head,
+            uint32_t kv_size,
+            int32_t rs_zero,
+            bool avoid_copies = false) const;
+
+    llm_graph_input_rs * build_rs_inp() const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false) const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
-
-
-
-            const llama_ubatch & ubatch,
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
             int il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
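The recurring change in llama-graph.h above is that each graph-input class now holds a read-only memory-context pointer (`llm_graph_input_rs`, `llm_graph_input_mem_hybrid`, the unified/iswa variants) instead of a direct KV-cache pointer, and fills its tensors per ubatch in `set_input()`. A minimal standalone sketch of that pattern in C++, using invented toy types rather than the real llama.cpp ones:

#include <cstdint>
#include <vector>

struct toy_ubatch         { std::vector<int32_t> seq_ids; };      // stand-in for llama_ubatch
struct toy_memory_context { std::vector<int32_t> head_per_seq; }; // stand-in for a *_context

struct toy_graph_input {
    virtual ~toy_graph_input() = default;
    virtual void set_input(const toy_ubatch * ubatch) = 0;
};

// analogous to llm_graph_input_rs: captures the memory context at graph-build time,
// then fills its per-ubatch data ("s_copy") when set_input() is called
struct toy_input_rs : toy_graph_input {
    explicit toy_input_rs(const toy_memory_context * mctx) : mctx(mctx) {}

    void set_input(const toy_ubatch * ubatch) override {
        s_copy.clear();
        for (int32_t seq : ubatch->seq_ids) {
            // copy the state-cell index the memory module picked for this sequence
            s_copy.push_back(mctx->head_per_seq[seq]);
        }
    }

    std::vector<int32_t> s_copy;     // plays the role of the I32 s_copy tensor
    const toy_memory_context * mctx; // read-only view of the memory module
};

int main() {
    toy_memory_context mctx{{7, 3}}; // cell heads chosen for sequence 0 and sequence 1
    toy_ubatch ubatch{{0, 1, 1}};    // three tokens belonging to sequences 0, 1, 1

    toy_input_rs inp(&mctx);
    inp.set_input(&ubatch);          // s_copy is now {7, 3, 3}
    return (int) inp.s_copy.size();
}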
package/src/llama.cpp/src/llama-hparams.cpp (+10 -2)

@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::
+uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::
+uint32_t llama_hparams::n_embd_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +86,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::is_recurrent(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
+uint32_t llama_hparams::n_pos_per_embd() const {
+    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
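One functional addition above worth calling out: `llama_hparams::n_pos_per_embd()` returns 4 for M-RoPE and 1 otherwise, which is the factor that scales the I32 `pos` input declared in llama-graph.h. A standalone sketch of that arithmetic (the enum, helper, and buffer-size use below are illustrative stand-ins, not the package's code):

#include <cstdint>
#include <cstdio>

// stand-in for llama.cpp's rope type constants (illustrative only)
enum toy_rope_type { TOY_ROPE_NORM, TOY_ROPE_MROPE };

// mirrors the rule in llama_hparams::n_pos_per_embd() above:
// M-RoPE carries 4 position components per embedding, everything else carries 1
static uint32_t n_pos_per_embd(toy_rope_type rope_type) {
    return rope_type == TOY_ROPE_MROPE ? 4 : 1;
}

int main() {
    const uint32_t n_tokens = 512;

    // a position buffer sized per token would presumably scale by this factor, e.g.:
    std::printf("norm : %u positions\n", n_tokens * n_pos_per_embd(TOY_ROPE_NORM));  // 512
    std::printf("mrope: %u positions\n", n_tokens * n_pos_per_embd(TOY_ROPE_MROPE)); // 2048
    return 0;
}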
package/src/llama.cpp/src/llama-hparams.h (+13 -2)

@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -131,6 +134,9 @@ struct llama_hparams {
     bool attn_soft_cap = false;
     bool use_kq_norm = true;
 
+    // for Classifiers
+    uint32_t n_cls_out = 1;
+
     // llama4
     uint32_t n_moe_layer_step = 0;
     uint32_t n_no_rope_layer_step = 4;
@@ -178,10 +184,15 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t
+    uint32_t n_embd_r() const;
 
     // dimension of the recurrent state embeddings
-    uint32_t
+    uint32_t n_embd_s() const;
+
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool is_recurrent(uint32_t il) const;
+
+    uint32_t n_pos_per_embd() const;
 
     bool is_swa(uint32_t il) const;
 };
package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp (+279 -0, new file)

@@ -0,0 +1,279 @@
+#include "llama-kv-cache-unified-iswa.h"
+
+#include "llama-impl.h"
+#include "llama-batch.h"
+#include "llama-model.h"
+
+#include <algorithm>
+#include <cassert>
+
+//
+// llama_kv_cache_unified_iswa
+//
+
+llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
+        const llama_model & model,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        bool offload,
+        bool swa_full,
+        uint32_t kv_size,
+        uint32_t n_seq_max,
+        uint32_t n_ubatch,
+        uint32_t n_pad) : hparams(model.hparams) {
+    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return model.hparams.is_swa(il); };
+
+    const uint32_t size_base = kv_size;
+
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+
+    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
+    if (swa_full) {
+        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+        size_swa = size_base;
+    }
+
+    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
+
+    kv_base = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_base), type_k, type_v,
+            v_trans, offload, size_base, n_seq_max, n_pad,
+            0, LLAMA_SWA_TYPE_NONE);
+
+    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
+
+    kv_swa = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_swa), type_k, type_v,
+            v_trans, offload, size_swa, n_seq_max, n_pad,
+            hparams.n_swa, hparams.swa_type);
+}
+
+void llama_kv_cache_unified_iswa::clear(bool data) {
+    kv_base->clear(data);
+    kv_swa ->clear(data);
+}
+
+bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    bool res = true;
+
+    res = res & kv_base->seq_rm(seq_id, p0, p1);
+    res = res & kv_swa ->seq_rm(seq_id, p0, p1);
+
+    return res;
+}
+
+void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
+    kv_base->seq_keep(seq_id);
+    kv_swa ->seq_keep(seq_id);
+}
+
+void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_base->seq_add(seq_id, p0, p1, shift);
+    kv_swa ->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_base->seq_div(seq_id, p0, p1, d);
+    kv_swa ->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
+    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
+    return kv_swa->seq_pos_min(seq_id);
+}
+
+llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
+    return kv_swa->seq_pos_max(seq_id);
+}
+
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);
+
+    // first try simple split
+    do {
+        balloc.split_reset();
+
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_simple(n_ubatch);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // if it fails, try equal split
+    do {
+        balloc.split_reset();
+
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_equal(n_ubatch);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // TODO: if we fail again, we should attempt different splitting strategies
+    // but to do that properly, we first have to refactor the batches to be more flexible
+
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+}
+
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
+}
+
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
+}
+
+bool llama_kv_cache_unified_iswa::get_can_shift() const {
+    return kv_base->get_size() == kv_swa->get_size();
+}
+
+void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    kv_base->state_write(io, seq_id);
+    kv_swa ->state_write(io, seq_id);
+}
+
+void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    kv_base->state_read(io, seq_id);
+    kv_swa ->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
+    return kv_base.get();
+}
+
+llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
+    return kv_swa.get();
+}
+
+//
+// llama_kv_cache_unified_iswa_context
+//
+
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+        llama_kv_cache_unified_iswa * kv) :
+    ctx_base(kv->get_base()->init_full()),
+    ctx_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+}
+
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+        llama_kv_cache_unified_iswa * kv,
+        llama_context * lctx,
+        bool optimize) :
+    ctx_base(kv->get_base()->init_update(lctx, optimize)),
+    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+}
+
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+        llama_kv_cache_unified_iswa * kv,
+        std::vector<uint32_t> heads_base,
+        std::vector<uint32_t> heads_swa,
+        std::vector<llama_ubatch> ubatches) :
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+}
+
+llama_kv_cache_unified_iswa_context::~llama_kv_cache_unified_iswa_context() = default;
+
+bool llama_kv_cache_unified_iswa_context::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    ctx_base->next();
+    ctx_swa ->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_unified_iswa_context::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    bool res = true;
+
+    res = res & ctx_base->apply();
+    res = res & ctx_swa ->apply();
+
+    return res;
+}
+
+llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
+}
+
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
+}
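The constructor above sizes the SWA cache as min(base size, pad(n_swa*n_seq_max + n_ubatch, n_pad)), i.e. just enough cells for one sliding window per sequence plus one micro-batch, rounded up to the padding granularity. A standalone sketch of that arithmetic with made-up numbers (the padding helper below is a local stand-in for GGML_PAD, not its actual definition):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// local stand-in for GGML_PAD: round x up to a multiple of n
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

// mirrors: size_swa = std::min(size_base, GGML_PAD(n_swa*n_seq_max + n_ubatch, n_pad))
static uint32_t swa_cache_size(uint32_t size_base, uint32_t n_swa, uint32_t n_seq_max,
                               uint32_t n_ubatch, uint32_t n_pad, bool swa_full) {
    if (swa_full) {
        return size_base; // full-size SWA cache: same number of cells as the base cache
    }
    return std::min(size_base, pad_to(n_swa * n_seq_max + n_ubatch, n_pad));
}

int main() {
    // example numbers (not taken from the package): 32k base cache, 4k sliding window,
    // 1 sequence, 512-token micro-batches, padding of 256 cells
    const uint32_t size_base = 32768;
    std::printf("size_swa            = %u\n", swa_cache_size(size_base, 4096, 1, 512, 256, false)); // 4608
    std::printf("size_swa (swa_full) = %u\n", swa_cache_size(size_base, 4096, 1, 512, 256, true));  // 32768
    return 0;
}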