@fugood/llama.node 1.1.7 → 1.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. package/lib/binding.ts +4 -0
  2. package/lib/index.js +9 -2
  3. package/lib/index.ts +57 -30
  4. package/lib/version.js +2 -2
  5. package/lib/version.ts +2 -2
  6. package/package.json +14 -14
  7. package/src/LlamaContext.cpp +20 -0
  8. package/src/common.hpp +8 -1
  9. package/src/llama.cpp/common/arg.cpp +13 -4
  10. package/src/llama.cpp/common/chat.cpp +33 -2
  11. package/src/llama.cpp/common/common.cpp +0 -15
  12. package/src/llama.cpp/common/common.h +6 -4
  13. package/src/llama.cpp/ggml/CMakeLists.txt +0 -1
  14. package/src/llama.cpp/ggml/include/ggml.h +25 -0
  15. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +66 -0
  16. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +316 -0
  17. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -3
  18. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  19. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +6 -0
  20. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +142 -0
  21. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  22. package/src/llama.cpp/include/llama.h +1 -110
  23. package/src/llama.cpp/src/CMakeLists.txt +2 -2
  24. package/src/llama.cpp/src/llama-arch.cpp +19 -0
  25. package/src/llama.cpp/src/llama-arch.h +1 -0
  26. package/src/llama.cpp/src/llama-chat.cpp +13 -2
  27. package/src/llama.cpp/src/llama-chat.h +1 -0
  28. package/src/llama.cpp/src/llama-context.cpp +5 -197
  29. package/src/llama.cpp/src/llama-context.h +2 -7
  30. package/src/llama.cpp/src/llama-cparams.h +0 -1
  31. package/src/llama.cpp/src/llama-graph.cpp +35 -57
  32. package/src/llama.cpp/src/llama-graph.h +36 -46
  33. package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +47 -47
  34. package/src/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +26 -26
  35. package/src/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +88 -441
  36. package/src/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +20 -43
  37. package/src/llama.cpp/src/llama-kv-cells.h +21 -21
  38. package/src/llama.cpp/src/llama-memory-hybrid.cpp +5 -5
  39. package/src/llama.cpp/src/llama-memory-hybrid.h +6 -6
  40. package/src/llama.cpp/src/llama-memory-recurrent.h +1 -1
  41. package/src/llama.cpp/src/llama-memory.h +3 -8
  42. package/src/llama.cpp/src/llama-model.cpp +449 -246
  43. package/src/llama.cpp/src/llama-model.h +2 -0
@@ -20,8 +20,8 @@ add_library(llama
     llama-hparams.cpp
     llama-impl.cpp
     llama-io.cpp
-    llama-kv-cache-unified.cpp
-    llama-kv-cache-unified-iswa.cpp
+    llama-kv-cache.cpp
+    llama-kv-cache-iswa.cpp
     llama-memory.cpp
     llama-memory-hybrid.cpp
     llama-memory-recurrent.cpp
@@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DREAM, "dream" },
     { LLM_ARCH_SMALLTHINKER, "smallthinker" },
     { LLM_ARCH_LLADA, "llada" },
+    { LLM_ARCH_SEED_OSS, "seed_oss" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -2010,6 +2011,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
             { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
             { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
         }
     },
     {
@@ -2067,6 +2069,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_SEED_OSS,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -97,6 +97,7 @@ enum llm_arch {
     LLM_ARCH_DREAM,
     LLM_ARCH_SMALLTHINKER,
     LLM_ARCH_LLADA,
+    LLM_ARCH_SEED_OSS,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -16,10 +16,10 @@
 static std::string trim(const std::string & str) {
     size_t start = 0;
     size_t end = str.size();
-    while (start < end && isspace(str[start])) {
+    while (start < end && isspace(static_cast<unsigned char>(str[start]))) {
         start += 1;
     }
-    while (end > start && isspace(str[end - 1])) {
+    while (end > start && isspace(static_cast<unsigned char>(str[end - 1]))) {
         end -= 1;
     }
     return str.substr(start, end - start);
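
Note: the isspace() change above addresses a classic pitfall: passing a plain char that holds a negative value (for example a UTF-8 continuation byte on platforms where char is signed) to isspace() is undefined behavior, while converting through unsigned char first keeps the argument in range. A minimal, self-contained illustration of the same pattern (not taken from this diff):

```cpp
// Sketch only: demonstrates the cast pattern adopted by the trim() fix above.
#include <cctype>

static bool is_space_safe(char c) {
    // Negative char values must not reach isspace() directly; widen via unsigned char.
    return std::isspace(static_cast<unsigned char>(c)) != 0;
}
```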
@@ -69,6 +69,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
     { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
     { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
+    { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -201,6 +202,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
     } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;
+    } else if (tmpl_contains("<seed:bos>")) {
+        return LLM_CHAT_TEMPLATE_SEED_OSS;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -752,6 +755,14 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|im_assistant|>assistant<|im_middle|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<seed:bos>" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << "<seed:eos>";
+        }
+        if (add_ass) {
+            ss << "<seed:bos>assistant\n";
+        }
     } else {
         // template not supported
         return -1;
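
For orientation, a hedged usage sketch (not part of this diff) of how the new Seed-OSS template can be exercised through the public llama_chat_apply_template() API, assuming the registered template name "seed_oss" added above; the expected output is derived from the rendering loop in the hunk:

```cpp
// Sketch under the assumption that llama.h from this package is on the include path.
#include "llama.h"

#include <string>
#include <vector>

int main() {
    const std::vector<llama_chat_message> msgs = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!"                       },
    };

    std::vector<char> buf(4096);
    const int32_t n = llama_chat_apply_template(
        "seed_oss", msgs.data(), msgs.size(), /*add_ass=*/true,
        buf.data(), (int32_t) buf.size());

    // Based on the template code above, the rendered prompt should be:
    //   <seed:bos>system\nYou are a helpful assistant.<seed:eos>
    //   <seed:bos>user\nHello!<seed:eos><seed:bos>assistant\n
    return (n > 0 && n <= (int32_t) buf.size()) ? 0 : 1;
}
```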
@@ -49,6 +49,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_OPENAI_MOE,
     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
+    LLM_CHAT_TEMPLATE_SEED_OSS,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
@@ -39,7 +39,6 @@ llama_context::llama_context(
     cparams.yarn_attn_factor = params.yarn_attn_factor;
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
-    cparams.defrag_thold = params.defrag_thold;
     cparams.embeddings = params.embeddings;
     cparams.offload_kqv = params.offload_kqv;
     cparams.flash_attn = params.flash_attn;
@@ -93,7 +92,7 @@ llama_context::llama_context(
     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
     // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
     if (cparams.n_batch < GGML_KQ_MASK_PAD) {
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
@@ -145,11 +144,6 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
-        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
-                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
-    }
-
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -444,26 +438,12 @@ llama_memory_t llama_context::get_memory() const {
     return memory.get();
 }
 
-// deprecated
-void llama_context::kv_self_defrag_sched() {
-    if (!memory) {
-        return;
-    }
-
-    memory_force_optimize = true;
-}
-
-// deprecated
-bool llama_context::kv_self_update(bool optimize) {
+bool llama_context::memory_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
     {
-        // TODO: remove in the future
-        optimize |= memory_force_optimize;
-        memory_force_optimize = false;
-
         const auto mctx = memory->init_update(this, optimize);
         switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
@@ -997,8 +977,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     bool did_optimize = false;
 
-    // handle any pending defrags/shifts
-    kv_self_update(false);
+    // handle any pending shifts/copies
+    memory_update(false);
 
     llama_memory_context_ptr mctx;
 
@@ -1023,7 +1003,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
             if (!did_optimize) {
                 did_optimize = true;
 
-                if (kv_self_update(true)) {
+                if (memory_update(true)) {
                     LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
 
                     continue;
@@ -2343,16 +2323,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
-// deprecated
-llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
-}
-
-// deprecated
-void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update(false);
-}
-
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
     return ctx->pooling_type();
 }
@@ -2570,168 +2540,6 @@ bool llama_memory_can_shift(llama_memory_t mem) {
     return mem->get_can_shift();
 }
 
-//
-// kv cache
-//
-
-// deprecated
-int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return 0;
-    }
-
-    int32_t res = 0;
-
-    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
-        const llama_pos p0 = kv->seq_pos_min(s);
-        const llama_pos p1 = kv->seq_pos_max(s);
-
-        if (p0 >= 0) {
-            res += (p1 - p0) + 1;
-        }
-    }
-
-    return res;
-}
-
-// deprecated
-// note: this is the same as above - will be removed anyway, so it's ok
-int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return 0;
-    }
-
-    int32_t res = 0;
-
-    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
-        const llama_pos p0 = kv->seq_pos_min(s);
-        const llama_pos p1 = kv->seq_pos_max(s);
-
-        if (p0 >= 0) {
-            res += (p1 - p0) + 1;
-        }
-    }
-
-    return res;
-}
-
-// deprecated
-void llama_kv_self_clear(llama_context * ctx) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_clear(kv, true);
-}
-
-// deprecated
-bool llama_kv_self_seq_rm(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return true;
-    }
-
-    return llama_memory_seq_rm(kv, seq_id, p0, p1);
-}
-
-// deprecated
-void llama_kv_self_seq_cp(
-        llama_context * ctx,
-        llama_seq_id seq_id_src,
-        llama_seq_id seq_id_dst,
-        llama_pos p0,
-        llama_pos p1) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
-}
-
-// deprecated
-void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_seq_keep(kv, seq_id);
-}
-
-// deprecated
-void llama_kv_self_seq_add(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        llama_pos delta) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_seq_add(kv, seq_id, p0, p1, delta);
-}
-
-// deprecated
-void llama_kv_self_seq_div(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        int d) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_seq_div(kv, seq_id, p0, p1, d);
-}
-
-// deprecated
-llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return -1;
-    }
-
-    return llama_memory_seq_pos_min(kv, seq_id);
-}
-
-// deprecated
-llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return -1;
-    }
-
-    return llama_memory_seq_pos_max(kv, seq_id);
-}
-
-// deprecated
-void llama_kv_self_defrag(llama_context * ctx) {
-    // force defrag
-    ctx->kv_self_defrag_sched();
-}
-
-// deprecated
-bool llama_kv_self_can_shift(const llama_context * ctx) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return false;
-    }
-
-    return llama_memory_can_shift(kv);
-}
-
 // llama state API
 
 // deprecated
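
The block removed above deletes the deprecated llama_kv_self_* wrappers from the public surface of llama-context.cpp. A hedged migration sketch (not part of this diff) for downstream callers, using the llama_memory_* functions that the removed wrapper bodies were already forwarding to:

```cpp
// Sketch only: equivalent calls via the memory API that replaces the removed wrappers.
#include "llama.h"

static void reset_sequences(llama_context * ctx) {
    llama_memory_t mem = llama_get_memory(ctx);
    if (!mem) {
        return; // e.g. a context without any memory/KV cache attached
    }

    // was: llama_kv_self_seq_rm(ctx, 0, 0, -1);
    llama_memory_seq_rm(mem, /*seq_id=*/0, /*p0=*/0, /*p1=*/-1);

    // was: llama_kv_self_clear(ctx);
    llama_memory_clear(mem, /*data=*/true);
}
```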
@@ -46,10 +46,8 @@ struct llama_context {
 
     llama_memory_t get_memory() const;
 
-    // return true of the KV cache was updated
-    // TODO: remove
-    bool kv_self_update(bool optimize);
-    void kv_self_defrag_sched();
+    // return true if the memory was updated
+    bool memory_update(bool optimize);
 
     enum llama_pooling_type pooling_type() const;
 
@@ -230,9 +228,6 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
-    // TODO: temporary, until the llama_kv_self_defrag() API is removed
-    bool memory_force_optimize = false;
-
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t logits_size = 0; // capacity (of floats) for logits
     float * logits = nullptr;
@@ -24,7 +24,6 @@ struct llama_cparams {
     float yarn_attn_factor;
     float yarn_beta_fast;
     float yarn_beta_slow;
-    float defrag_thold;
 
     bool embeddings;
     bool causal_attn;
@@ -4,8 +4,8 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 
-#include "llama-kv-cache-unified.h"
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"
 
@@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
             for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                 const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                // TODO: reimplement this like in llama_kv_cache_unified
+                // TODO: reimplement this like in llama_kv_cache
                 if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
                     if (hparams.use_alibi) {
                         f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
@@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->set_input_v_idxs(self_v_idxs, ubatch);
 
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
     return res;
 }
 
-void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
@@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
@@ -1223,8 +1223,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
-        ggml_tensor * v_mla,
         ggml_tensor * sinks,
+        ggml_tensor * v_mla,
         float kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
@@ -1360,6 +1360,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
@@ -1381,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1399,17 +1400,17 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
         ggml_context * ctx0,
         const llama_ubatch & ubatch,
         const llama_hparams & hparams,
         const llama_cparams & cparams,
-        const llama_kv_cache_unified_context * mctx_cur) {
+        const llama_kv_cache_context * mctx_cur) {
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
 
     {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
 
         const auto n_kv = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
@@ -1427,22 +1428,23 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
     return inp;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
-    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified * inp,
+        llm_graph_input_attn_kv * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
@@ -1469,7 +1471,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1488,40 +1490,15 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified_iswa * inp,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    return build_attn_with_sinks(
-            inp,
-            wo,
-            wo_b,
-            q_cur,
-            k_cur,
-            v_cur,
-            kq_b,
-            v_mla,
-            nullptr,
-            kq_scale,
-            il);
-}
-
-ggml_tensor * llm_graph_context::build_attn_with_sinks(
-        llm_graph_input_attn_kv_unified_iswa * inp,
+        llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
         ggml_tensor * sinks,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1561,7 +1538,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1600,6 +1577,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
@@ -1615,7 +1593,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1636,10 +1614,10 @@ ggml_tensor * llm_graph_context::build_attn(
 // TODO: maybe separate the inner implementation into a separate function
 // like with the non-sliding window equivalent
 // once sliding-window hybrid caches are a thing.
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
 
     const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
@@ -1656,7 +1634,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     }
 
     {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
@@ -1669,7 +1647,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1792,7 +1770,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
     auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
-    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);