@fugood/llama.node 1.1.5 → 1.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. package/lib/binding.ts +4 -0
  2. package/lib/index.js +6 -1
  3. package/lib/index.ts +6 -0
  4. package/lib/version.js +5 -0
  5. package/lib/version.ts +2 -0
  6. package/package.json +14 -14
  7. package/scripts/llama.cpp.patch +19 -15
  8. package/src/LlamaCompletionWorker.cpp +73 -18
  9. package/src/LlamaCompletionWorker.h +8 -0
  10. package/src/llama.cpp/CMakeLists.txt +2 -0
  11. package/src/llama.cpp/common/arg.cpp +147 -46
  12. package/src/llama.cpp/common/chat-parser.cpp +9 -1
  13. package/src/llama.cpp/common/chat.cpp +350 -3
  14. package/src/llama.cpp/common/chat.h +11 -3
  15. package/src/llama.cpp/common/common.cpp +54 -0
  16. package/src/llama.cpp/common/common.h +44 -9
  17. package/src/llama.cpp/ggml/CMakeLists.txt +5 -2
  18. package/src/llama.cpp/ggml/include/ggml-opt.h +25 -6
  19. package/src/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  20. package/src/llama.cpp/ggml/include/ggml.h +65 -3
  21. package/src/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  22. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +1 -1
  23. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +61 -0
  24. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +96 -8
  25. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +1136 -1077
  26. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +20 -0
  27. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +20 -1
  28. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  29. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  30. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +270 -11
  31. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +3 -8
  32. package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  33. package/src/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  34. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +200 -51
  35. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +11 -0
  36. package/src/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  37. package/src/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  38. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +19 -4
  39. package/src/llama.cpp/include/llama.h +26 -0
  40. package/src/llama.cpp/src/llama-arch.cpp +65 -0
  41. package/src/llama.cpp/src/llama-arch.h +10 -0
  42. package/src/llama.cpp/src/llama-batch.cpp +1 -1
  43. package/src/llama.cpp/src/llama-chat.cpp +15 -4
  44. package/src/llama.cpp/src/llama-chat.h +1 -0
  45. package/src/llama.cpp/src/llama-context.cpp +37 -25
  46. package/src/llama.cpp/src/llama-context.h +6 -5
  47. package/src/llama.cpp/src/llama-graph.cpp +118 -9
  48. package/src/llama.cpp/src/llama-graph.h +38 -0
  49. package/src/llama.cpp/src/llama-hparams.h +5 -3
  50. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +12 -6
  51. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +2 -2
  52. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +93 -69
  53. package/src/llama.cpp/src/llama-kv-cache-unified.h +2 -2
  54. package/src/llama.cpp/src/llama-memory-hybrid.cpp +6 -2
  55. package/src/llama.cpp/src/llama-memory-hybrid.h +2 -2
  56. package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -2
  57. package/src/llama.cpp/src/llama-memory-recurrent.h +2 -2
  58. package/src/llama.cpp/src/llama-memory.h +2 -2
  59. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  60. package/src/llama.cpp/src/llama-model-loader.h +3 -2
  61. package/src/llama.cpp/src/llama-model.cpp +500 -4
  62. package/src/llama.cpp/src/llama-model.h +25 -4
  63. package/src/llama.cpp/src/llama-quant.cpp +37 -1
  64. package/src/llama.cpp/src/llama-vocab.cpp +43 -0
package/src/llama.cpp/src/llama-batch.cpp

@@ -477,7 +477,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
     if (sequential && has_cpl) {
-        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__);
 
        return {};
    }
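
The new hint points at the unified KV cache option. Below is a minimal, hypothetical sketch of enabling it from the C API; the field name kv_unified is an assumption inferred from the cparams.kv_unified usage visible later in this diff, and -kvu is the CLI spelling referenced by the log message.

    // Hypothetical sketch: enable a unified KV cache so coupled sequences can be split.
    // Assumption: the -kvu CLI flag corresponds to a kv_unified field in llama_context_params.
    #include "llama.h"

    llama_context * make_unified_ctx(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.kv_unified = true; // assumed public mirror of the -kvu flag
        return llama_init_from_model(model, cparams);
    }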
package/src/llama.cpp/src/llama-chat.cpp

@@ -66,6 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
     { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+    { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
     { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
     { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
 };

@@ -192,9 +193,11 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
     } else if (tmpl_contains("<|endofuserprompt|>")) {
         return LLM_CHAT_TEMPLATE_DOTS1;
-    } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
+    } else if (tmpl_contains("<|extra_0|>") && tmpl_contains("<|extra_4|>")) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
-    } else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) {
+    } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
+        return LLM_CHAT_TEMPLATE_OPENAI_MOE;
+    } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
     } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;

@@ -622,8 +625,6 @@ int32_t llm_chat_apply_template(
     } else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
         // Yandex template ("\n\n" is defined as EOT token)
 
-        ss << "<s>";
-
         for (size_t i = 0; i < chat.size(); i++) {
             std::string role(chat[i]->role);
             if (role == "user") {

@@ -706,6 +707,16 @@ int32_t llm_chat_apply_template(
                 ss << "<|startoftext|>" << message->content << "<|extra_0|>";
             }
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) {
+        // OpenAI MoE (based on Harmony chat template)
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|start|>" << role << "<|message|>" << message->content;
+            ss << (role == "assistant" ? "<|return|>" : "<|end|>");
+        }
+        if (add_ass) {
+            ss << "<|start|>assistant";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) {
         // tencent/Hunyuan-4B-Instruct
         for (size_t i = 0; i < chat.size(); i++) {
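
For reference, a minimal sketch of exercising the new "gpt-oss" template through the public llama_chat_apply_template API; the expected output string follows directly from the branch added above (assistant turns end with <|return|> instead of <|end|>).

    // Sketch: rendering a conversation with the new "gpt-oss" (Harmony-style) template.
    #include "llama.h"
    #include <string>
    #include <vector>

    std::string render_gpt_oss_prompt() {
        std::vector<llama_chat_message> chat = {
            { "system", "You are a helpful assistant." },
            { "user",   "Hello!" },
        };

        std::vector<char> buf(4096);
        const int32_t n = llama_chat_apply_template(
            "gpt-oss", chat.data(), chat.size(), /*add_ass=*/true, buf.data(), (int32_t) buf.size());

        // Expected (one string, shown here on separate lines for readability):
        //   <|start|>system<|message|>You are a helpful assistant.<|end|>
        //   <|start|>user<|message|>Hello!<|end|>
        //   <|start|>assistant
        return n > 0 ? std::string(buf.data(), n) : std::string();
    }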
package/src/llama.cpp/src/llama-chat.h

@@ -46,6 +46,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+    LLM_CHAT_TEMPLATE_OPENAI_MOE,
     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
package/src/llama.cpp/src/llama-context.cpp

@@ -786,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;
 
     const int64_t n_embd = hparams.n_embd;
-    const int32_t n_vocab = model.vocab.n_tokens();
+    const int64_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
     if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {

@@ -959,7 +959,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto & vocab = model.vocab;
     const auto & hparams = model.hparams;
 
-    const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_vocab = vocab.n_tokens();
     const int64_t n_embd = hparams.n_embd;
 
     // when computing embeddings, all tokens are output

@@ -1328,21 +1328,21 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 }
 
 void llama_context::output_reorder() {
-    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_vocab = model.vocab.n_tokens();
     const uint64_t n_embd = model.hparams.n_embd;
 
-    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
-        const uint32_t i0 = output_swaps[s].i0;
-        const uint32_t i1 = output_swaps[s].i1;
+    for (size_t s = 0; s < output_swaps.size(); ++s) {
+        const uint64_t i0 = output_swaps[s].i0;
+        const uint64_t i1 = output_swaps[s].i1;
 
         if (logits_size > 0) {
-            for (uint32_t k = 0; k < n_vocab; k++) {
+            for (uint64_t k = 0; k < n_vocab; k++) {
                 std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
             }
         }
 
         if (embd_size > 0) {
-            for (uint32_t k = 0; k < n_embd; k++) {
+            for (uint64_t k = 0; k < n_embd; k++) {
                 std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
             }
         }

@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
     }
 }
 
-size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
+size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
     llama_io_write_dummy io;
     try {
-        return state_seq_write_data(io, seq_id);
+        return state_seq_write_data(io, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
         return 0;
     }
 }
 
-size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
+size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
     llama_io_write_buffer io(dst, size);
     try {
-        return state_seq_write_data(io, seq_id);
+        return state_seq_write_data(io, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
 }
 
-size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
+size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
     llama_io_read_buffer io(src, size);
     try {
-        return state_seq_read_data(io, seq_id);
+        return state_seq_read_data(io, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;

@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
     {
         const size_t state_size = file.size() - file.tell();
         llama_io_read_file io(&file);
-        const size_t nread = state_seq_read_data(io, seq_id);
+        const size_t nread = state_seq_read_data(io, seq_id, 0);
         if (!nread) {
             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
             return 0;

@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
 
     // save the context state using stream saving
     llama_io_write_file io(&file);
-    state_seq_write_data(io, seq_id);
+    state_seq_write_data(io, seq_id, 0);
 
     const size_t res = file.tell();
     GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());

@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     return io.n_bytes();
 }
 
-size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
+size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        memory->state_write(io, seq_id);
+        memory->state_write(io, seq_id, flags);
     }
 
     return io.n_bytes();
 }
 
-size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
+size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        memory->state_read(io, seq_id);
+        memory->state_read(io, seq_id, flags);
     }
 
     return io.n_bytes();

@@ -2048,7 +2048,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
     opt_params.opt_period = n_batch / n_ubatch;
     opt_params.get_opt_pars = lopt_params.get_opt_pars;
     opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
-
+    opt_params.optimizer = lopt_params.optimizer_type;
     opt_ctx = ggml_opt_init(opt_params);
 
     llama_opt_param_filter param_filter = lopt_params.param_filter;

@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
 }
 
 size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
-    return ctx->state_seq_get_size(seq_id);
+    return llama_state_seq_get_size_ext(ctx, seq_id, 0);
 }
 
 size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+    return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
+}
+
+size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+    return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
+}
+
+size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    return ctx->state_seq_get_size(seq_id, flags);
+}
+
+size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
     ctx->synchronize();
 
-    return ctx->state_seq_get_data(seq_id, dst, size);
+    return ctx->state_seq_get_data(seq_id, dst, size, flags);
 }
 
-size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
     ctx->synchronize();
 
-    return ctx->state_seq_set_data(seq_id, src, size);
+    return ctx->state_seq_set_data(seq_id, src, size, flags);
 }
 
 size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
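
A minimal usage sketch of the new _ext entry points defined above; passing 0 keeps the legacy behaviour, while LLAMA_STATE_SEQ_FLAGS_SWA_ONLY (used by the iSWA cache changes later in this diff) restricts the snapshot to the SWA KV cache. The surrounding setup is assumed.

    // Sketch: snapshotting one sequence's state with the flags-aware API.
    #include "llama.h"
    #include <vector>

    static bool copy_seq_state(llama_context * src, llama_context * dst,
                               llama_seq_id seq_id, llama_state_seq_flags flags) {
        const size_t n = llama_state_seq_get_size_ext(src, seq_id, flags);
        if (n == 0) {
            return false;
        }

        std::vector<uint8_t> buf(n);
        if (llama_state_seq_get_data_ext(src, buf.data(), buf.size(), seq_id, flags) == 0) {
            return false;
        }
        // With flags == 0 this is equivalent to the legacy llama_state_seq_set_data().
        return llama_state_seq_set_data_ext(dst, buf.data(), buf.size(), seq_id, flags) != 0;
    }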
package/src/llama.cpp/src/llama-context.h

@@ -111,9 +111,9 @@ struct llama_context {
     size_t state_get_data( uint8_t * dst, size_t size);
     size_t state_set_data(const uint8_t * src, size_t size);
 
-    size_t state_seq_get_size(llama_seq_id seq_id);
-    size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
-    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
+    size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
+    size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
+    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
 
     bool state_load_file(
             const char * filepath,

@@ -152,6 +152,7 @@ struct llama_context {
 
     void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
 
+    // TODO: more flexible combinations of logical/physical batch size and context size
     void opt_epoch(
             ggml_opt_dataset_t dataset,
             ggml_opt_result_t result_train,

@@ -212,8 +213,8 @@ private:
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i & io);
 
-    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
-    size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
+    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
+    size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
 
     //
     // members
package/src/llama.cpp/src/llama-graph.cpp

@@ -740,6 +740,8 @@ ggml_tensor * llm_graph_context::build_ffn(
                 cur = ggml_reglu(ctx0, cur);
                 cb(cur, "ffn_reglu", il);
             } break;
+        default:
+            GGML_ABORT("fatal error");
     }
 
     if (gate && type_gate == LLM_FFN_PAR) {

@@ -749,8 +751,8 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     if (down) {
         cur = build_lora_mm(down, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }

@@ -787,6 +789,45 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         llama_expert_gating_func_type gating_op,
         int il,
         ggml_tensor * probs_in) const {
+    return build_moe_ffn(
+        cur,
+        gate_inp, /* gate_inp_b */ nullptr,
+        up_exps, /* up_exps_b */ nullptr,
+        gate_exps, /* gate_exps_b */ nullptr,
+        down_exps, /* down_exps_b */ nullptr,
+        exp_probs_b,
+        n_expert,
+        n_expert_used,
+        type_op,
+        norm_w,
+        scale_w,
+        w_scale,
+        gating_op,
+        il,
+        probs_in
+    );
+}
+
+ggml_tensor * llm_graph_context::build_moe_ffn(
+        ggml_tensor * cur,
+        ggml_tensor * gate_inp,
+        ggml_tensor * gate_inp_b,
+        ggml_tensor * up_exps,
+        ggml_tensor * up_exps_b,
+        ggml_tensor * gate_exps,
+        ggml_tensor * gate_exps_b,
+        ggml_tensor * down_exps,
+        ggml_tensor * down_exps_b,
+        ggml_tensor * exp_probs_b,
+        int64_t n_expert,
+        int64_t n_expert_used,
+        llm_ffn_op_type type_op,
+        bool norm_w,
+        bool scale_w,
+        float w_scale,
+        llama_expert_gating_func_type gating_op,
+        int il,
+        ggml_tensor * probs_in) const {
     const int64_t n_embd = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
     const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

@@ -800,6 +841,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         logits = probs_in;
     }
 
+    if (gate_inp_b) {
+        logits = ggml_add(ctx0, logits, gate_inp_b);
+        cb(logits, "ffn_moe_logits_biased", il);
+    }
+
     ggml_tensor * probs = nullptr;
     switch (gating_op) {
         case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:

@@ -810,6 +856,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             {
                 probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
             } break;
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
+            {
+                probs = logits; // [n_expert, n_tokens]
+            } break;
         default:
             GGML_ABORT("fatal error");
     }

@@ -838,6 +888,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
 
+    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
+        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+        weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
+        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+        cb(weights, "ffn_moe_weights_softmax", il);
+    }
+
     if (norm_w) {
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
 

@@ -866,6 +923,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
+    if (up_exps_b) {
+        up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
+        cb(up, "ffn_moe_up_biased", il);
+    }
+
     ggml_tensor * experts = nullptr;
     if (gate_exps) {
         cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]

@@ -874,6 +936,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cur = up;
     }
 
+    if (gate_exps_b) {
+        cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
+        cb(cur, "ffn_moe_gate_biased", il);
+    }
+
     switch (type_op) {
         case LLM_FFN_SILU:
             if (gate_exps) {

@@ -891,6 +958,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
+        case LLM_FFN_SWIGLU_OAI_MOE:
+            {
+                // TODO: move to hparams?
+                constexpr float alpha = 1.702f;
+                constexpr float limit = 7.0f;
+                cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
+                cb(cur, "ffn_moe_swiglu_oai", il);
+            } break;
         case LLM_FFN_RELU:
             if (gate_exps) {
                 cur = ggml_reglu_split(ctx0, cur, up);

@@ -906,6 +981,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
+    if (down_exps_b) {
+        experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
+        cb(experts, "ffn_moe_down_biased", il);
+    }
+
     if (!weight_before_ffn) {
         experts = ggml_mul(ctx0, experts, weights);
         cb(cur, "ffn_moe_weighted", il);

@@ -1144,6 +1224,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * v_mla,
+        ggml_tensor * sinks,
         float kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 

@@ -1180,7 +1261,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+        ggml_flash_attn_ext_add_sinks(cur, sinks);
+        ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
 
         if (v_mla) {
 #if 0

@@ -1228,6 +1310,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         }
 
         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        ggml_soft_max_add_sinks(kq, sinks);
 
         if (!v_trans) {
             // note: avoid this branch

@@ -1298,7 +1381,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1386,13 +1469,13 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
         cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }

@@ -1415,6 +1498,32 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
+    return build_attn_with_sinks(
+        inp,
+        wo,
+        wo_b,
+        q_cur,
+        k_cur,
+        v_cur,
+        kq_b,
+        v_mla,
+        nullptr,
+        kq_scale,
+        il);
+}
+
+ggml_tensor * llm_graph_context::build_attn_with_sinks(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        ggml_tensor * sinks,
+        float kq_scale,
+        int il) const {
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);

@@ -1452,7 +1561,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

@@ -1506,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
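
To illustrate the intended call shape of the bias-aware overload added above, here is a hypothetical fragment in the style of a model builder; the layer tensor names are illustrative only, while the op type and gating enum values are the ones introduced in this diff.

    // Hypothetical fragment (inside an llm_graph_context-based model builder):
    // a gpt-oss style MoE block routed through the bias-aware overload.
    ggml_tensor * moe_out = build_moe_ffn(cur,
            model.layers[il].ffn_gate_inp,  model.layers[il].ffn_gate_inp_b,   // illustrative names
            model.layers[il].ffn_up_exps,   model.layers[il].ffn_up_exps_b,
            model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
            model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
            nullptr,                        // exp_probs_b
            n_expert, n_expert_used,
            LLM_FFN_SWIGLU_OAI_MOE,
            /*norm_w=*/false, /*scale_w=*/false, /*w_scale=*/0.0f,
            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
            il);
    cb(moe_out, "ffn_moe_out", il);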
package/src/llama.cpp/src/llama-graph.h

@@ -39,6 +39,7 @@ enum llm_ffn_op_type {
     LLM_FFN_SWIGLU,
     LLM_FFN_GEGLU,
     LLM_FFN_REGLU,
+    LLM_FFN_SWIGLU_OAI_MOE,
 };
 
 enum llm_ffn_gate_type {

@@ -619,6 +620,7 @@ struct llm_graph_context {
             llm_ffn_gate_type type_gate,
             int il) const;
 
+    // build MoE FFN without bias tensors
     ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
             ggml_tensor * gate_inp,

@@ -636,6 +638,27 @@ struct llm_graph_context {
             int il,
             ggml_tensor * probs_in = nullptr) const;
 
+    ggml_tensor * build_moe_ffn(
+            ggml_tensor * cur,
+            ggml_tensor * gate_inp,
+            ggml_tensor * gate_inp_b,
+            ggml_tensor * up_exps,
+            ggml_tensor * up_exps_b,
+            ggml_tensor * gate_exps,
+            ggml_tensor * gate_exps_b,
+            ggml_tensor * down_exps,
+            ggml_tensor * down_exps_b,
+            ggml_tensor * exp_probs_b,
+            int64_t n_expert,
+            int64_t n_expert_used,
+            llm_ffn_op_type type_op,
+            bool norm_w,
+            bool scale_w,
+            float w_scale,
+            llama_expert_gating_func_type gating_op,
+            int il,
+            ggml_tensor * probs_in = nullptr) const;
+
     //
     // inputs
     //

@@ -662,6 +685,7 @@ struct llm_graph_context {
             ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
             ggml_tensor * kq_b,
             ggml_tensor * kq_mask,
+            ggml_tensor * sinks,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale) const;
 

@@ -708,6 +732,20 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
+    // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
+    ggml_tensor * build_attn_with_sinks(
+            llm_graph_input_attn_kv_unified_iswa * inp,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            ggml_tensor * sinks, // [n_head_q]
+            float kq_scale,
+            int il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
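
Similarly, a hypothetical fragment showing how an attention block with learned sinks would go through build_attn_with_sinks; the parameter order matches the declaration above, everything else (input object, Q/K/V tensors, sink tensor name) is illustrative.

    // Hypothetical fragment: attention with per-head sinks added to the softmax.
    cur = build_attn_with_sinks(inp_attn,
            model.layers[il].wo, model.layers[il].bo,       // illustrative names
            Qcur, Kcur, Vcur,
            /*kq_b=*/nullptr, /*v_mla=*/nullptr,
            model.layers[il].attn_sinks,                    // [n_head_q]
            1.0f/sqrtf(float(n_embd_head)),
            il);
    cb(cur, "attn_out", il);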
package/src/llama.cpp/src/llama-hparams.h

@@ -9,9 +9,10 @@
 #define LLAMA_MAX_EXPERTS 384 // Kimi-K2
 
 enum llama_expert_gating_func_type {
-    LLAMA_EXPERT_GATING_FUNC_TYPE_NONE    = 0,
-    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
-    LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_NONE           = 0,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX        = 1,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID        = 2,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits
 };
 
 enum llama_swa_type {

@@ -73,6 +74,7 @@ struct llama_hparams {
     bool expert_weights_norm = false;
     uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
     uint32_t moe_every_n_layers = 0;
+    uint32_t nextn_predict_layers = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;
package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp

@@ -194,14 +194,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }
 
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    kv_base->state_write(io, seq_id);
-    kv_swa ->state_write(io, seq_id);
+void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_write(io, seq_id, flags);
+    }
+
+    kv_swa->state_write(io, seq_id, flags);
 }
 
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    kv_base->state_read(io, seq_id);
-    kv_swa ->state_read(io, seq_id);
+void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_read(io, seq_id, flags);
+    }
+
+    kv_swa->state_read(io, seq_id, flags);
 }
 
 llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h

@@ -56,8 +56,8 @@ public:
 
     // state write/load
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
     // llama_kv_cache_unified_iswa specific API