@fugood/llama.node 1.1.9 → 1.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/lib/binding.ts CHANGED
@@ -100,6 +100,11 @@ export type LlamaCompletionOptions = {
  enable_thinking?: boolean
  thinking_forced_open?: boolean
  prompt?: string
+ /**
+ * Text to prefill the response with.
+ * This text will be added to the beginning of the generated response.
+ */
+ prefill_text?: string
  temperature?: number
  top_k?: number
  top_p?: number
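For reference, a minimal TypeScript sketch of how the new prefill_text option might be used. The loadModel entry point, the completion() call shape, and the result.text field are assumptions for illustration only; the option names (prompt, prefill_text, temperature, top_k, top_p, n_predict) come from this release's LlamaCompletionOptions.

  import { loadModel } from '@fugood/llama.node'

  // Hypothetical usage: steer the start of the reply with a fixed prefix.
  const context = await loadModel({ model: 'path/to/model.gguf' })

  const result = await context.completion({
    prompt: 'User: Summarize the release in one line.\nAssistant:',
    prefill_text: 'Sure: ', // new in 1.1.10; prepended to the generated response
    temperature: 0.8,
    top_k: 40,
    top_p: 0.95,
    n_predict: 64,
  })

  console.log(result.text)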
package/package.json CHANGED
@@ -1,7 +1,7 @@
  {
  "name": "@fugood/llama.node",
  "access": "public",
- "version": "1.1.9",
+ "version": "1.1.10",
  "description": "An another Node binding of llama.cpp",
  "main": "lib/index.js",
  "scripts": {
@@ -71,19 +71,19 @@
  "CMakeLists.txt"
  ],
  "optionalDependencies": {
- "@fugood/node-llama-linux-x64": "1.1.9",
- "@fugood/node-llama-linux-x64-vulkan": "1.1.9",
- "@fugood/node-llama-linux-x64-cuda": "1.1.9",
- "@fugood/node-llama-linux-arm64": "1.1.9",
- "@fugood/node-llama-linux-arm64-vulkan": "1.1.9",
- "@fugood/node-llama-linux-arm64-cuda": "1.1.9",
- "@fugood/node-llama-win32-x64": "1.1.9",
- "@fugood/node-llama-win32-x64-vulkan": "1.1.9",
- "@fugood/node-llama-win32-x64-cuda": "1.1.9",
- "@fugood/node-llama-win32-arm64": "1.1.9",
- "@fugood/node-llama-win32-arm64-vulkan": "1.1.9",
- "@fugood/node-llama-darwin-x64": "1.1.9",
- "@fugood/node-llama-darwin-arm64": "1.1.9"
+ "@fugood/node-llama-linux-x64": "1.1.10",
+ "@fugood/node-llama-linux-x64-vulkan": "1.1.10",
+ "@fugood/node-llama-linux-x64-cuda": "1.1.10",
+ "@fugood/node-llama-linux-arm64": "1.1.10",
+ "@fugood/node-llama-linux-arm64-vulkan": "1.1.10",
+ "@fugood/node-llama-linux-arm64-cuda": "1.1.10",
+ "@fugood/node-llama-win32-x64": "1.1.10",
+ "@fugood/node-llama-win32-x64-vulkan": "1.1.10",
+ "@fugood/node-llama-win32-x64-cuda": "1.1.10",
+ "@fugood/node-llama-win32-arm64": "1.1.10",
+ "@fugood/node-llama-win32-arm64-vulkan": "1.1.10",
+ "@fugood/node-llama-darwin-x64": "1.1.10",
+ "@fugood/node-llama-darwin-arm64": "1.1.10"
  },
  "devDependencies": {
  "@babel/preset-env": "^7.24.4",
@@ -1,5 +1,5 @@
  diff --git a/src/llama.cpp/common/chat.cpp b/src/llama.cpp/common/chat.cpp
- index 23d3828f9..ca48af00c 100644
+ index 111b4a21b..16ce87672 100644
  --- a/src/llama.cpp/common/chat.cpp
  +++ b/src/llama.cpp/common/chat.cpp
  @@ -6,9 +6,6 @@
@@ -29,6 +29,16 @@ index 23d3828f9..ca48af00c 100644
  struct templates_params {
  json messages;
  json tools;
+ @@ -784,8 +771,7 @@ static std::string apply(
+ if (additional_context) {
+ tmpl_inputs.extra_context.merge_patch(*additional_context);
+ }
+ - // TODO: add flag to control date/time, if only for testing purposes.
+ - // tmpl_inputs.now = std::chrono::system_clock::now();
+ + tmpl_inputs.now = inputs.now;
+
+ minja::chat_template_options tmpl_opts;
+ // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
  diff --git a/src/llama.cpp/common/chat.h b/src/llama.cpp/common/chat.h
  index d1e480c91..437e64e29 100644
  --- a/src/llama.cpp/common/chat.h
@@ -54,10 +64,10 @@ index d1e480c91..437e64e29 100644
  struct common_chat_tool_call {
  std::string name;
  diff --git a/src/llama.cpp/common/common.cpp b/src/llama.cpp/common/common.cpp
- index 67dd5404f..909a97c66 100644
+ index fdce1dcde..55aac3412 100644
  --- a/src/llama.cpp/common/common.cpp
  +++ b/src/llama.cpp/common/common.cpp
- @@ -1117,6 +1117,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
+ @@ -1103,6 +1103,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
  mparams.n_gpu_layers = params.n_gpu_layers;
  }

@@ -66,10 +76,10 @@ index 67dd5404f..909a97c66 100644
  mparams.split_mode = params.split_mode;
  mparams.tensor_split = params.tensor_split;
  diff --git a/src/llama.cpp/common/common.h b/src/llama.cpp/common/common.h
- index 75596e6b3..0e04694c8 100644
+ index 390dda5e5..f259ca785 100644
  --- a/src/llama.cpp/common/common.h
  +++ b/src/llama.cpp/common/common.h
- @@ -267,6 +267,7 @@ struct lr_opt {
+ @@ -270,6 +270,7 @@ struct lr_opt {
  struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);

  struct common_params {
@@ -35,12 +35,14 @@ LlamaCompletionWorker::LlamaCompletionWorker(
  const std::vector<std::string> &media_paths,
  const std::vector<llama_token> &guide_tokens,
  bool has_vocoder,
- tts_type tts_type_val)
+ tts_type tts_type_val,
+ const std::string &prefill_text)
  : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess),
  _params(params), _stop_words(stop_words), _chat_format(chat_format),
  _thinking_forced_open(thinking_forced_open),
  _reasoning_format(reasoning_format),
  _media_paths(media_paths), _guide_tokens(guide_tokens),
+ _prefill_text(prefill_text),
  _has_vocoder(has_vocoder), _tts_type(tts_type_val) {
  if (!callback.IsEmpty()) {
  _tsfn = Napi::ThreadSafeFunction::New(info.Env(), callback,
@@ -68,8 +70,11 @@ LlamaCompletionWorker::PartialOutput LlamaCompletionWorker::getPartialOutput(con

  chat_syntax.parse_tool_calls = true;

+ // Combine prefill_text with generated_text for parsing
+ std::string full_text = _prefill_text + generated_text;
+
  // Use is_partial=true for streaming partial output
- common_chat_msg parsed_msg = common_chat_parse(generated_text, true, chat_syntax);
+ common_chat_msg parsed_msg = common_chat_parse(full_text, true, chat_syntax);

  result.content = parsed_msg.content;
  result.reasoning_content = parsed_msg.reasoning_content;
@@ -156,6 +161,7 @@ void LlamaCompletionWorker::Execute() {
  auto embd = _sess->tokens_ptr();
  embd->reserve(embd->size() + max_len);

+
  if (is_enc_dec) {
  if (n_input > 0) {
  // Decode tokens in batches using n_batch as chunk size
@@ -378,8 +384,11 @@ void LlamaCompletionWorker::OnOK() {
  chat_syntax.thinking_forced_open = _thinking_forced_open;

  chat_syntax.reasoning_format = common_reasoning_format_from_name(_reasoning_format);
+
+ // Combine prefill_text with generated_text for final parsing
+ std::string full_text = _prefill_text + _result.text;
  common_chat_msg message = common_chat_parse(
- _result.text,
+ full_text,
  false,
  chat_syntax
  );
@@ -26,7 +26,8 @@ public:
  const std::vector<std::string> &media_paths = {},
  const std::vector<llama_token> &guide_tokens = {},
  bool has_vocoder = false,
- tts_type tts_type_val = UNKNOWN);
+ tts_type tts_type_val = UNKNOWN,
+ const std::string &prefill_text = "");

  ~LlamaCompletionWorker();

@@ -58,6 +59,7 @@ private:
  std::string _reasoning_format;
  std::vector<std::string> _media_paths;
  std::vector<llama_token> _guide_tokens;
+ std::string _prefill_text;
  std::function<void()> _onComplete;
  bool _has_callback = false;
  bool _interrupted = false;
@@ -935,6 +935,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  json_schema_to_grammar(json::parse(json_schema_str));
  }

+ std::string prefill_text = get_option<std::string>(options, "prefill_text", "");
+
  params.n_predict = get_option<int32_t>(options, "n_predict", -1);
  params.sampling.temp = get_option<float>(options, "temperature", 0.80f);
  params.sampling.top_k = get_option<int32_t>(options, "top_k", 40);
@@ -1007,7 +1009,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  auto *worker =
  new LlamaCompletionWorker(info, _sess, callback, params, stop_words,
  chat_format, thinking_forced_open, reasoning_format, media_paths, guide_tokens,
- _has_vocoder, _tts_type);
+ _has_vocoder, _tts_type, prefill_text);
  worker->Queue();
  _wip = worker;
  worker->OnComplete([this]() { _wip = nullptr; });
@@ -771,8 +771,7 @@ static std::string apply(
  if (additional_context) {
  tmpl_inputs.extra_context.merge_patch(*additional_context);
  }
- // TODO: add flag to control date/time, if only for testing purposes.
- // tmpl_inputs.now = std::chrono::system_clock::now();
+ tmpl_inputs.now = inputs.now;

  minja::chat_template_options tmpl_opts;
  // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {

  GGML_ABORT("fatal error");
  }
+
+ bool llama_hparams::has_kv(uint32_t il) const {
+ if (n_layer_kv_from_start >= 0) {
+ if (il < (uint32_t) n_layer_kv_from_start) {
+ return true;
+ }
+
+ return false;
+ }
+
+ // by default, all layers have kv
+ return true;
+ }
+
+ uint32_t llama_hparams::n_layer_kv() const {
+ uint32_t res = 0;
+
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ if (has_kv(il)) {
+ res++;
+ }
+ }
+
+ return res;
+ }
@@ -41,6 +41,7 @@ struct llama_hparams {
  uint32_t n_embd;
  uint32_t n_embd_features = 0;
  uint32_t n_layer;
+ int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
  uint32_t n_rot;
  uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
  uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
  uint32_t n_pos_per_embd() const;

  bool is_swa(uint32_t il) const;
+
+ bool has_kv(uint32_t il) const;
+
+ // number of layers for which has_kv() returns true
+ uint32_t n_layer_kv() const;
  };

  static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
@@ -22,9 +22,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
  uint32_t kv_size,
  uint32_t n_seq_max,
  uint32_t n_ubatch,
- uint32_t n_pad) : hparams(model.hparams), unified(unified) {
- llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
- llama_kv_cache::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
+ uint32_t n_pad,
+ const layer_filter_cb & filter,
+ const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
+
+ // chain filters
+ const layer_filter_cb filter_base = [&](int32_t il) {
+ if (filter && !filter(il)) {
+ return false;
+ }
+
+ return !model.hparams.is_swa(il);
+ };
+
+ const layer_filter_cb filter_swa = [&](int32_t il) {
+ if (filter && !filter(il)) {
+ return false;
+ }
+
+ return model.hparams.is_swa(il);
+ };

  const uint32_t size_base = kv_size;

@@ -41,16 +58,16 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
  LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

  kv_base = std::make_unique<llama_kv_cache>(
- model, std::move(filter_base), type_k, type_v,
+ model, type_k, type_v,
  v_trans, offload, unified, size_base, n_seq_max, n_pad,
- 0, LLAMA_SWA_TYPE_NONE);
+ 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);

  LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

  kv_swa = std::make_unique<llama_kv_cache>(
- model, std::move(filter_swa), type_k, type_v,
+ model, type_k, type_v,
  v_trans, offload, unified, size_swa, n_seq_max, n_pad,
- hparams.n_swa, hparams.swa_type);
+ hparams.n_swa, hparams.swa_type, filter_swa, reuse);
  }

  void llama_kv_cache_iswa::clear(bool data) {
@@ -20,11 +20,13 @@ public:
  bool v_trans,
  bool offload,
  bool swa_full,
- bool ,
+ bool unified,
  uint32_t kv_size,
  uint32_t n_seq_max,
  uint32_t n_ubatch,
- uint32_t n_pad);
+ uint32_t n_pad,
+ const layer_filter_cb & filter,
+ const layer_reuse_cb & reuse);

  ~llama_kv_cache_iswa() = default;

@@ -17,32 +17,25 @@
  //

  llama_kv_cache::llama_kv_cache(
- const llama_model & model,
- layer_filter_cb && filter,
- ggml_type type_k,
- ggml_type type_v,
- bool v_trans,
- bool offload,
- bool unified,
- uint32_t kv_size,
- uint32_t n_seq_max,
- uint32_t n_pad,
- uint32_t n_swa,
- llama_swa_type swa_type) :
+ const llama_model & model,
+ ggml_type type_k,
+ ggml_type type_v,
+ bool v_trans,
+ bool offload,
+ bool unified,
+ uint32_t kv_size,
+ uint32_t n_seq_max,
+ uint32_t n_pad,
+ uint32_t n_swa,
+ llama_swa_type swa_type,
+ const layer_filter_cb & filter,
+ const layer_reuse_cb & reuse) :
  model(model), hparams(model.hparams), v_trans(v_trans),
  n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

  GGML_ASSERT(kv_size % n_pad == 0);

- // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
- auto n_layer_cache = hparams.n_layer;
- if (model.arch == LLM_ARCH_GEMMA3N) {
- n_layer_cache = 20;
- }
- if (model.arch == LLM_ARCH_GLM4_MOE) {
- // GLM-4.5: Only process up to last layer, skip final NextN layer
- n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
- }
+ const uint32_t n_layer_kv = hparams.n_layer_kv();

  // create a context for each buffer type
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -50,7 +43,7 @@ llama_kv_cache::llama_kv_cache(
  auto it = ctx_map.find(buft);
  if (it == ctx_map.end()) {
  ggml_init_params params = {
- /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
+ /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
  /*.mem_buffer =*/ NULL,
  /*.no_alloc =*/ true,
  };
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
  __func__, hparams.n_embd_v_gqa_max());
  }

- for (uint32_t il = 0; il < n_layer_cache; il++) {
+ for (uint32_t il = 0; il < hparams.n_layer; il++) {
+ if (!hparams.has_kv(il)) {
+ LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
+ continue;
+ }
+
  if (filter && !filter(il)) {
- LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+ LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
  continue;
  }

@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
  layers.push_back({ il, k, v, k_stream, v_stream, });
  }

- // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
- if (model.arch == LLM_ARCH_GEMMA3N) {
- LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+ if (reuse) {
+ LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);

- for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
- if (filter && !filter(il)) {
- LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+ for (uint32_t il = 0; il < hparams.n_layer; il++) {
+ const int32_t il_reuse = reuse(il);
+
+ if (il_reuse < 0) {
+ LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
  continue;
  }

- const bool is_swa = hparams.is_swa(il);
- const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+ if (filter && !filter(il)) {
+ LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
+ continue;
+ }

  GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+
  map_layer_ids[il] = map_layer_ids[il_reuse];

- LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+ LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
  }
  }

@@ -21,9 +21,6 @@ class llama_kv_cache : public llama_memory_i {
  public:
  static uint32_t get_padding(const llama_cparams & cparams);

- // this callback is used to filter out layers that should not be included in the cache
- using layer_filter_cb = std::function<bool(int32_t il)>;
-
  struct stream_copy_info {
  bool empty() const {
  assert(ssrc.size() == sdst.size());
@@ -82,18 +79,19 @@ public:
  using slot_info_vec_t = std::vector<slot_info>;

  llama_kv_cache(
- const llama_model & model,
- layer_filter_cb && filter,
- ggml_type type_k,
- ggml_type type_v,
- bool v_trans,
- bool offload,
- bool unified,
- uint32_t kv_size,
- uint32_t n_seq_max,
- uint32_t n_pad,
- uint32_t n_swa,
- llama_swa_type swa_type);
+ const llama_model & model,
+ ggml_type type_k,
+ ggml_type type_v,
+ bool v_trans,
+ bool offload,
+ bool unified,
+ uint32_t kv_size,
+ uint32_t n_seq_max,
+ uint32_t n_pad,
+ uint32_t n_swa,
+ llama_swa_type swa_type,
+ const layer_filter_cb & filter,
+ const layer_reuse_cb & reuse);

  ~llama_kv_cache() = default;

@@ -9,32 +9,29 @@
  //

  llama_memory_hybrid::llama_memory_hybrid(
- const llama_model & model,
- /* attn */
- ggml_type type_k,
- ggml_type type_v,
- bool v_trans,
- uint32_t kv_size,
- uint32_t n_pad,
- uint32_t n_swa,
- llama_swa_type swa_type,
- /* recurrent */
- ggml_type type_r,
- ggml_type type_s,
- uint32_t rs_size,
- /* common */
- uint32_t n_seq_max,
- bool offload,
- bool unified,
- /* layer filters */
- layer_filter_cb && filter_attn,
- layer_filter_cb && filter_recr) :
+ const llama_model & model,
+ /* attn */
+ ggml_type type_k,
+ ggml_type type_v,
+ bool v_trans,
+ uint32_t kv_size,
+ uint32_t n_pad,
+ uint32_t n_swa,
+ llama_swa_type swa_type,
+ /* recurrent */
+ ggml_type type_r,
+ ggml_type type_s,
+ uint32_t rs_size,
+ /* common */
+ uint32_t n_seq_max,
+ bool offload,
+ bool unified,
+ /* layer filters */
+ const layer_filter_cb & filter_attn,
+ const layer_filter_cb & filter_recr) :
  hparams(model.hparams),
  mem_attn(new llama_kv_cache(
  model,
- filter_attn == nullptr ?
- [&](int32_t il) { return !hparams.is_recurrent(il); }
- : filter_attn,
  type_k,
  type_v,
  v_trans,
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
  n_seq_max,
  n_pad,
  n_swa,
- swa_type
+ swa_type,
+ filter_attn == nullptr ?
+ [&](int32_t il) { return !hparams.is_recurrent(il); }
+ : filter_attn,
+ nullptr
  )),
  mem_recr(new llama_memory_recurrent(
  model,
- filter_recr == nullptr ?
- [&](int32_t il) { return hparams.is_recurrent(il); }
- : filter_recr,
  type_r,
  type_s,
  offload,
  rs_size,
- n_seq_max
+ n_seq_max,
+ filter_recr == nullptr ?
+ [&](int32_t il) { return hparams.is_recurrent(il); }
+ : filter_recr
  )) {}

  llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
@@ -18,31 +18,27 @@

  class llama_memory_hybrid : public llama_memory_i {
  public:
-
- // this callback is used to filter out layers that should not be included in the cache
- using layer_filter_cb = std::function<bool(int32_t il)>;
-
  llama_memory_hybrid(
  const llama_model & model,
  /* attn */
- ggml_type type_k,
- ggml_type type_v,
- bool v_trans,
- uint32_t kv_size,
- uint32_t n_pad,
- uint32_t n_swa,
- llama_swa_type swa_type,
- /* recurrent */
- ggml_type type_r,
- ggml_type type_s,
- uint32_t rs_size,
- /* common */
- uint32_t n_seq_max,
- bool offload,
- bool unified,
- /* layer filters */
- layer_filter_cb && filter_attn = nullptr,
- layer_filter_cb && filter_recr = nullptr);
+ ggml_type type_k,
+ ggml_type type_v,
+ bool v_trans,
+ uint32_t kv_size,
+ uint32_t n_pad,
+ uint32_t n_swa,
+ llama_swa_type swa_type,
+ /* recurrent */
+ ggml_type type_r,
+ ggml_type type_s,
+ uint32_t rs_size,
+ /* common */
+ uint32_t n_seq_max,
+ bool offload,
+ bool unified,
+ /* layer filters */
+ const layer_filter_cb & filter_attn = nullptr,
+ const layer_filter_cb & filter_recr = nullptr);

  ~llama_memory_hybrid() = default;

@@ -16,13 +16,13 @@
  //

  llama_memory_recurrent::llama_memory_recurrent(
- const llama_model & model,
- layer_filter_cb && filter,
- ggml_type type_r,
- ggml_type type_s,
- bool offload,
- uint32_t mem_size,
- uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+ const llama_model & model,
+ ggml_type type_r,
+ ggml_type type_s,
+ bool offload,
+ uint32_t mem_size,
+ uint32_t n_seq_max,
+ const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
  const int32_t n_layer = hparams.n_layer;

  head = 0;
@@ -15,18 +15,14 @@
  // see the implementation of llama_kv_cache_context_i for an example how to do it
  class llama_memory_recurrent : public llama_memory_i {
  public:
-
- // this callback is used to filter out layers that should not be included in the cache
- using layer_filter_cb = std::function<bool(int32_t il)>;
-
  llama_memory_recurrent(
- const llama_model & model,
- layer_filter_cb && filter,
- ggml_type type_r,
- ggml_type type_s,
- bool offload,
- uint32_t mem_size,
- uint32_t n_seq_max);
+ const llama_model & model,
+ ggml_type type_r,
+ ggml_type type_s,
+ bool offload,
+ uint32_t mem_size,
+ uint32_t n_seq_max,
+ const layer_filter_cb & filter);

  ~llama_memory_recurrent() = default;

@@ -3,6 +3,7 @@
  #include "llama.h"

  #include <memory>
+ #include <functional>

  struct llama_ubatch;

@@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
  // general concept of LLM memory
  // the KV cache is a type of LLM memory, but there can be other types
  struct llama_memory_i {
+ // this callback is used to filter out layers that should not be included in the cache
+ using layer_filter_cb = std::function<bool(int32_t il)>;
+
+ // this callback is used to specify which layers should reuse memory from other layers
+ // return negative value to indicate that the layer il should not reuse memory
+ using layer_reuse_cb = std::function<int32_t(int32_t il)>;
+
  virtual ~llama_memory_i() = default;

  // split the input batch into a set of ubatches and verify that they can fit into the cache
@@ -1115,6 +1115,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
  hparams.set_swa_pattern(5);

+ hparams.n_layer_kv_from_start = 20;
  hparams.rope_freq_base_train_swa = 10000.0f;
  hparams.rope_freq_scale_train_swa = 1.0f;
  hparams.f_attention_scale = 1.0f;
@@ -1474,12 +1475,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  // Expert gating function (GLM-4.5 uses sigmoid)
  ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
  if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
- hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
+ hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
  }

  // NextN/MTP parameters
  ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);

+ // TODO: when MTP is implemented, this should probably be updated if needed
+ hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+
  switch (hparams.n_layer) {
  case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
  case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
@@ -10524,7 +10528,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
  const int64_t n_embd_altup;
  const int64_t n_altup;
  const int i_altup_act;
- const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
  const int n_layer_sparsity = 10; // number of layers using activation sparsity
  const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)

@@ -10574,8 +10577,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {

  for (int il = 0; il < n_layer; ++il) {
  // this block is made to be closely resemble Gemma3p5DecoderLayer on python code
- const bool has_kv = (il < n_layer_kv);
-
  const float freq_base_l = model.get_rope_freq_base (cparams, il);
  const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

@@ -10595,7 +10596,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
  ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]

  // self-attention
- if (has_kv) {
+ if (hparams.has_kv(il)) {
  // compute Q and K and RoPE them
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
  cb(Qcur, "Qcur", il);
@@ -10635,7 +10636,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
  model.layers[il].wo, NULL,
  Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
  } else {
- // no KV layers
+ // reuse KV cache of earlier layers
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
  cb(Qcur, "Qcur", il);
  Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -18256,12 +18257,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
  if (llm_arch_is_recurrent(arch)) {
  res = new llama_memory_recurrent(
  *this,
- nullptr,
  GGML_TYPE_F32,
  GGML_TYPE_F32,
  cparams.offload_kqv,
  std::max((uint32_t) 1, cparams.n_seq_max),
- cparams.n_seq_max);
+ cparams.n_seq_max,
+ nullptr);
  } else if (llm_arch_is_hybrid(arch)) {
  const auto padding = llama_kv_cache::get_padding(cparams);

@@ -18302,6 +18303,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,

  LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

+ llama_memory_i::layer_reuse_cb reuse = nullptr;
+
+ if (arch == LLM_ARCH_GEMMA3N) {
+ reuse = [&](int32_t il) {
+ if (il >= (int32_t) hparams.n_layer_kv_from_start) {
+ return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
+ }
+
+ return -1;
+ };
+ }
+
  if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
  GGML_ASSERT(hparams.is_swa_any());

@@ -18316,13 +18329,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
  n_ctx_per_stream,
  cparams.n_seq_max,
  cparams.n_ubatch,
- padding);
+ padding,
+ nullptr,
+ reuse);
  } else {
  GGML_ASSERT(!hparams.is_swa_any());

  res = new llama_kv_cache(
  *this,
- nullptr,
  params.type_k,
  params.type_v,
  !cparams.flash_attn,
@@ -18332,7 +18346,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
  cparams.n_seq_max,
  padding,
  hparams.n_swa,
- hparams.swa_type);
+ hparams.swa_type,
+ nullptr,
+ nullptr);
  }
  }
  }