@fugood/llama.node 1.0.0-beta.7 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. package/CMakeLists.txt +2 -0
  2. package/lib/binding.ts +10 -0
  3. package/lib/index.js +8 -0
  4. package/lib/index.ts +14 -0
  5. package/package.json +14 -14
  6. package/src/LlamaContext.cpp +37 -0
  7. package/src/LlamaContext.h +1 -0
  8. package/src/RerankWorker.h +26 -0
  9. package/src/llama.cpp/CMakeLists.txt +1 -1
  10. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  11. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  12. package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -0
  13. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +8 -0
  14. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  15. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  16. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +13 -12
  17. package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  18. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  19. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  20. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  21. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  22. package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  23. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  24. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  25. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  26. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  27. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +59 -16
  28. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +3 -0
  29. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  30. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +48 -48
  31. package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  32. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +15 -14
  33. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  34. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  35. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  36. package/src/llama.cpp/include/llama.h +6 -3
  37. package/src/llama.cpp/src/llama-arch.cpp +54 -0
  38. package/src/llama.cpp/src/llama-arch.h +17 -0
  39. package/src/llama.cpp/src/llama-batch.cpp +20 -7
  40. package/src/llama.cpp/src/llama-chat.cpp +11 -6
  41. package/src/llama.cpp/src/llama-context.cpp +0 -1
  42. package/src/llama.cpp/src/llama-graph.cpp +19 -4
  43. package/src/llama.cpp/src/llama-graph.h +14 -2
  44. package/src/llama.cpp/src/llama-hparams.h +6 -0
  45. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +28 -2
  46. package/src/llama.cpp/src/llama-kv-cells.h +33 -9
  47. package/src/llama.cpp/src/llama-model.cpp +518 -1
  48. package/src/llama.cpp/src/llama-model.h +22 -0
  49. package/src/llama.cpp/src/llama-quant.cpp +87 -5
@@ -390,6 +390,7 @@ extern "C" {
         void * imatrix;      // pointer to importance matrix data
         void * kv_overrides; // pointer to vector containing overrides
         void * tensor_types; // pointer to vector containing tensor types
+        void * prune_layers; // pointer to vector containing layer indices to prune
     } llama_model_quantize_params;

     typedef struct llama_logit_bias {
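The new `prune_layers` field lets callers drop whole layers while quantizing. Below is a minimal, hedged sketch of how a caller might set it; the hunk only says the field points to "a vector containing layer indices", so the `uint32_t` element type and the chosen indices are assumptions for illustration:

```cpp
#include "llama.h"
#include <vector>

int main() {
    llama_model_quantize_params qparams = llama_model_quantize_default_params();
    qparams.ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M;

    // layer indices to prune during quantization (hypothetical values;
    // the element type of the pointed-to vector is an assumption here)
    std::vector<uint32_t> prune = { 10, 11, 12 };
    qparams.prune_layers = &prune;

    // quantize the model, pruning the listed layers
    return llama_model_quantize("model-f16.gguf", "model-q4_k_m-pruned.gguf", &qparams);
}
```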
@@ -943,12 +944,14 @@ extern "C" {
    // Requires the context to have a memory.
    // For encode-decoder contexts, processes the batch using the decoder.
    // Positive return values does not mean a fatal error, but rather a warning.
-   // Upon non-zero return values, the memory state is restored to the state before this call
+   // Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context
+   // To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
+   // Upon other return values, the memory state is restored to the state before this call
    //    0 - success
    //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-   //    2 - aborted
+   //    2 - aborted (processed ubatches will remain in the context's memory)
    //   -1 - invalid input batch
-   // < -1 - error
+   // < -1 - fatal error (processed ubatches will remain in the context's memory)
    LLAMA_API int32_t llama_decode(
            struct llama_context * ctx,
              struct llama_batch   batch);
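This changes the contract for callers: on abort (2) or fatal error (< -1), ubatches that were already processed stay in the context's memory instead of being rolled back. A hedged sketch of the recovery pattern the new comments suggest, using the `llama_memory_*` calls the comment names (the surrounding error handling is illustrative only):

```cpp
#include "llama.h"
#include <cstdio>

// illustrative sketch: react to the new llama_decode() error semantics
static void decode_and_report(llama_context * ctx, llama_batch batch, llama_seq_id seq_id) {
    const int32_t ret = llama_decode(ctx, batch);

    if (ret == 2 || ret < -1) {
        // abort or fatal error: the ubatches that were already processed
        // remain in the context's memory (KV cache), so inspect what made it in
        llama_memory_t mem = llama_get_memory(ctx);

        const llama_pos pos_min = llama_memory_seq_pos_min(mem, seq_id);
        const llama_pos pos_max = llama_memory_seq_pos_max(mem, seq_id);

        std::fprintf(stderr, "decode failed (ret = %d): seq %d now spans positions [%d, %d] in memory\n",
                ret, seq_id, pos_min, pos_max);

        // the caller can then resubmit the remaining tokens starting at pos_max + 1,
        // or remove the partial tail with llama_memory_seq_rm() and retry
    } else if (ret == 1) {
        // warning: no KV slot found - retry with a smaller batch or a larger context
    } else if (ret == -1) {
        // invalid input batch
    }
}
```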
@@ -42,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA,      "gemma"      },
     { LLM_ARCH_GEMMA2,     "gemma2"     },
     { LLM_ARCH_GEMMA3,     "gemma3"     },
+    { LLM_ARCH_GEMMA3N,    "gemma3n"    },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA,      "mamba"      },
     { LLM_ARCH_XVERSE,     "xverse"     },
@@ -932,6 +933,42 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3N,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,          "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,            "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,               "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,          "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,               "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,          "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,               "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,             "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,       "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM,             "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,             "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,             "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,               "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,        "blk.%d.post_ffw_norm" },
+            { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
+            { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
+            { LLM_TENSOR_PER_LAYER_PROJ_NORM,  "per_layer_proj_norm" },
+            { LLM_TENSOR_ALTUP_UNEMBD_PROJ,    "altup_unembd_proj" },
+            { LLM_TENSOR_ALTUP_PROJ,           "altup_proj" },
+            { LLM_TENSOR_PER_LAYER_INP_GATE,   "blk.%d.inp_gate" },
+            { LLM_TENSOR_PER_LAYER_PROJ,       "blk.%d.proj" },
+            { LLM_TENSOR_PER_LAYER_POST_NORM,  "blk.%d.post_norm" },
+            { LLM_TENSOR_ALTUP_CORRECT_COEF,   "blk.%d.altup_correct_coef" },
+            { LLM_TENSOR_ALTUP_CORRECT_SCALE,  "blk.%d.altup_correct_scale" },
+            { LLM_TENSOR_ALTUP_PREDICT_COEF,   "blk.%d.altup_predict_coef" },
+            { LLM_TENSOR_ALTUP_ROUTER,         "blk.%d.altup_router" },
+            { LLM_TENSOR_ALTUP_ROUTER_NORM,    "blk.%d.altup_router_norm" },
+            { LLM_TENSOR_LAUREL_L,             "blk.%d.laurel_l" },
+            { LLM_TENSOR_LAUREL_R,             "blk.%d.laurel_r" },
+            { LLM_TENSOR_LAUREL_POST_NORM,     "blk.%d.laurel_post_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -1749,6 +1786,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_GATE_EXPS,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    // altup / laurel (gemma 3n)
+    {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ_NORM,  {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PROJ,           {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_UNEMBD_PROJ,    {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_INP_GATE,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_POST_NORM,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_CORRECT_COEF,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_CORRECT_SCALE,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PREDICT_COEF,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER_NORM,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_LAUREL_L,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_R,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_POST_NORM,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D,               {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
@@ -46,6 +46,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
+    LLM_ARCH_GEMMA3N,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
@@ -269,6 +270,22 @@ enum llm_tensor {
     LLM_TENSOR_LAYER_OUT_NORM,
     LLM_TENSOR_POST_ATTN_NORM,
     LLM_TENSOR_POST_MLP_NORM,
+    LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
+    LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
+    LLM_TENSOR_PER_LAYER_INP_GATE,   // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ,       // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ_NORM,  // gemma3n
+    LLM_TENSOR_PER_LAYER_POST_NORM,  // gemma3n
+    LLM_TENSOR_ALTUP_PROJ,           // gemma3n
+    LLM_TENSOR_ALTUP_UNEMBD_PROJ,    // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_COEF,   // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_SCALE,  // gemma3n
+    LLM_TENSOR_ALTUP_PREDICT_COEF,   // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER,         // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER_NORM,    // gemma3n
+    LLM_TENSOR_LAUREL_L,             // gemma3n
+    LLM_TENSOR_LAUREL_R,             // gemma3n
+    LLM_TENSOR_LAUREL_POST_NORM,     // gemma3n
     LLM_TENSOR_SSM_IN,
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
@@ -244,22 +244,35 @@ bool llama_batch_allocr::init(
             continue;
         }

-        if (memory) {
+        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+        if (p0 >= 0) {
+            bool ok = true;
+
             if (batch.token) {
-                if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                if (seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             } else {
                 assert(batch.embd);

                 // for embeddings (typically used as vision input), we allow them to have repeating positions
                 // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             }
+
+            if (!ok) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
         }

         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
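The check now reports both endpoints of the rule Y = X + 1, where X is the last position already stored in the KV cache for the sequence and Y is the first position in the incoming batch. A hedged sketch of building a follow-up batch that satisfies the rule, filling the public `llama_batch` fields directly (token values and the single-sequence layout are illustrative):

```cpp
#include "llama.h"
#include <vector>

// continue sequence `seq_id` so that the first new position is exactly
// one past the last position already stored in the context's memory
static llama_batch make_followup_batch(llama_context * ctx,
                                       const std::vector<llama_token> & tokens,
                                       llama_seq_id seq_id) {
    // X = last cached position for this sequence (-1 if the sequence is empty,
    // in which case the batch simply starts at position 0)
    const llama_pos x = llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id);

    llama_batch batch = llama_batch_init((int32_t) tokens.size(), /*embd=*/0, /*n_seq_max=*/1);

    for (size_t i = 0; i < tokens.size(); ++i) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = x + 1 + (llama_pos) i; // Y = X + 1, then consecutive
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = seq_id;
        batch.logits  [i]    = (i == tokens.size() - 1); // only need logits for the last token
    }
    batch.n_tokens = (int32_t) tokens.size();

    return batch; // free later with llama_batch_free()
}
```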
@@ -528,12 +528,17 @@ int32_t llm_chat_apply_template(
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << "User: " << message->content << "\n\nAssistant:";
-            } else {
-                ss << message->content << "\n\n";
+        for (size_t i = 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "system") {
+                ss << "System: " << trim(chat[i]->content) << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << trim(chat[i]->content) << "\n\n";
+                if (i == chat.size() - 1) {
+                    ss << "Assistant:";
+                }
+            } else if (role == "assistant") {
+                ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
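The reworked RWKV-World branch now labels all three roles, trims the content, and appends the trailing "Assistant:" only after the final user turn. A standalone sketch that mirrors the new loop for a sample conversation (the `trim` helper here is a local stand-in for the one used in llama-chat.cpp):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// local stand-in for the trim() helper used in llama-chat.cpp
static std::string trim(const std::string & s) {
    const size_t b = s.find_first_not_of(" \t\n");
    const size_t e = s.find_last_not_of(" \t\n");
    return b == std::string::npos ? "" : s.substr(b, e - b + 1);
}

int main() {
    struct msg { std::string role, content; };
    const std::vector<msg> chat = {
        { "system",    "You are a helpful assistant. " },
        { "user",      "Hi!"                           },
        { "assistant", "Hello, how can I help?"        },
        { "user",      "What is RWKV?"                 },
    };

    std::ostringstream ss;
    for (size_t i = 0; i < chat.size(); i++) {
        if (chat[i].role == "system") {
            ss << "System: " << trim(chat[i].content) << "\n\n";
        } else if (chat[i].role == "user") {
            ss << "User: " << trim(chat[i].content) << "\n\n";
            if (i == chat.size() - 1) {
                ss << "Assistant:";   // prompt the model to answer the last user turn
            }
        } else if (chat[i].role == "assistant") {
            ss << "Assistant: " << trim(chat[i].content) << "\n\n";
        }
    }

    // prints the rendered prompt, ending with a bare "Assistant:" after the last user turn
    std::cout << ss.str();
    return 0;
}
```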
@@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
             }

-            // TODO: fix sequence indexing
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];

@@ -350,6 +350,12 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }

+void llm_graph_input_one::set_input(const llama_ubatch *) {
+    GGML_ASSERT(one && ggml_nelements(one) == 1);
+    float f_one = 1.0f;
+    ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+}
+
 //
 // llm_graph_context
 //
@@ -1267,8 +1273,14 @@ ggml_tensor * llm_graph_context::build_attn(
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
+
+    if (k_cur) {
+        ggml_build_forward_expand(gf, k_cur);
+    }
+
+    if (v_cur) {
+        ggml_build_forward_expand(gf, v_cur);
+    }

     const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);

@@ -1276,9 +1288,12 @@

     const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();

-    // store to KV cache
-    {
+    // optionally store to KV cache
+    if (k_cur) {
         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+    }
+
+    if (v_cur) {
         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }

@@ -329,6 +329,17 @@ public:
     const llama_memory_hybrid_context * mctx;
 };

+// TODO: remove this when ggml_scale_add is implemented
+class llm_graph_input_one : public llm_graph_input_i {
+public:
+    llm_graph_input_one() {}
+    virtual ~llm_graph_input_one() = default;
+
+    void set_input(const llama_ubatch *) override;
+
+    ggml_tensor * one = nullptr; // F32
+};
+
 //
 // llm_graph_result
 //
@@ -589,14 +600,15 @@ struct llm_graph_context {

     llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;

+    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
@@ -143,6 +143,12 @@ struct llama_hparams {
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale       = 0.1;

+    // gemma3n altup
+    uint32_t n_altup      = 4;   // altup_num_inputs
+    uint32_t i_altup_act  = 0;   // altup_active_idx
+    uint32_t laurel_rank  = 64;
+    uint32_t n_embd_altup = 256;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(

     GGML_ASSERT(kv_size % n_pad == 0);

+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    auto n_layer_cache = hparams.n_layer;
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        n_layer_cache = 20;
+    }
+
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(

     cells.resize(kv_size);

-    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+    for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
             continue;
@@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         layers.push_back({ il, k, v });
     }

+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+
+        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+                continue;
+            }
+
+            const bool is_swa = hparams.is_swa(il);
+            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+
+            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+            map_layer_ids[il] = map_layer_ids[il_reuse];
+
+            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+        }
+    }
+
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
     for (auto it : ctx_map) {
         auto * buft = it.first;
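For Gemma 3n only the first 20 layers get their own K/V tensors; every layer above that is mapped onto one of the last two cached layers, picked by its attention type (presumably so that a sliding-window layer reuses a cached sliding-window layer and a full-attention layer a full-attention one). A small self-contained sketch of that mapping rule; the total layer count and the alternating SWA pattern below are illustrative, not read from a real model:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_layer       = 30; // total layers (illustrative)
    const uint32_t n_layer_cache = 20; // layers that own KV tensors, per the diff

    for (uint32_t il = n_layer_cache; il < n_layer; il++) {
        // assumption for the sketch: alternate SWA / full-attention layers
        const bool is_swa = (il % 2 == 0);

        // same rule as the diff: SWA layers map to cached layer 18,
        // full-attention layers map to cached layer 19
        const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);

        std::printf("layer %2u reuses KV of layer %2u (swa = %d)\n", il, il_reuse, is_swa);
    }
    return 0;
}
```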
@@ -7,6 +7,7 @@
 #include <cassert>
 #include <vector>
 #include <set>
+#include <map>

 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -164,7 +165,7 @@ public:
         assert(seq_id >= 0);

         seq[i].reset(seq_id);
-        seq_pos[seq_id].erase(pos[i]);
+        seq_pos_dec(seq_id, pos[i]);

         if (seq[i].none()) {
             pos[i] = -1;
@@ -187,7 +188,7 @@ public:
             seq[i].reset();

             seq[i].set(seq_id);
-            seq_pos[seq_id].insert(pos[i]);
+            seq_pos_inc(seq_id, pos[i]);

             return false;
         }
@@ -232,7 +233,7 @@ public:
         assert(!seq[i].test(seq_id));

         seq[i].set(seq_id);
-        seq_pos[seq_id].insert(pos[i]);
+        seq_pos_inc(seq_id, pos[i]);
     }

     // return the sequence id of this cell
@@ -259,7 +260,9 @@ public:
             return -1;
         }

-        return *seq_pos[seq_id].begin();
+        assert(seq_pos[seq_id].begin()->second > 0);
+
+        return seq_pos[seq_id].begin()->first;
     }

     // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ public:
             return -1;
         }

-        return *seq_pos[seq_id].rbegin();
+        assert(seq_pos[seq_id].rbegin()->second > 0);
+
+        return seq_pos[seq_id].rbegin()->first;
     }

     // note: call only if the cell is not empty
@@ -389,17 +394,36 @@ private:
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
     std::vector<seq_set_t> seq;

-    // the set seq_pos[s] tells us which positions are currently present for sequence s
+    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
+    // if the position p is not present, seq_pos[s][p] is not set
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
+    //
+    // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
+    //   - during performing a cache reuse via (rm + add)
+    //   - some vision models have input embeddings with repeating positions
+    //
+    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];

     // helper functions for updating `seq_pos`, once cell at a time:

+    void seq_pos_dec(llama_seq_id s, llama_pos p) {
+        auto it = seq_pos[s].find(p);
+        assert(it != seq_pos[s].end());
+
+        if (--it->second == 0) {
+            seq_pos[s].erase(it);
+        }
+    }
+
+    void seq_pos_inc(llama_seq_id s, llama_pos p) {
+        seq_pos[s][p]++;
+    }
+
     // remove cell i
     void seq_pos_rm(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].erase(pos[i]);
+                seq_pos_dec(s, pos[i]);
             }
         }
     }
@@ -408,7 +432,7 @@ private:
     void seq_pos_add(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].insert(pos[i]);
+                seq_pos_inc(s, pos[i]);
             }
         }
     }
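The switch from `std::set<llama_pos>` to `std::map<llama_pos, int>` turns the per-sequence position index into a reference-counted multiset: a position only disappears once every cell holding it has been released. A tiny self-contained sketch of why a plain `std::set` would lose track when the same position occurs twice (as in cache reuse or repeated vision-embedding positions):

```cpp
#include <cassert>
#include <map>
#include <set>

int main() {
    // with a plain set, adding position 7 twice and removing it once
    // already makes it vanish, even though one cell still holds it
    std::set<int> pos_set;
    pos_set.insert(7);
    pos_set.insert(7);             // duplicate is silently ignored
    pos_set.erase(7);
    assert(pos_set.count(7) == 0); // position lost too early

    // with a count map, the position survives until the last holder is removed
    std::map<int, int> pos_cnt;
    auto inc = [&](int p) { pos_cnt[p]++; };
    auto dec = [&](int p) {
        auto it = pos_cnt.find(p);
        assert(it != pos_cnt.end());
        if (--it->second == 0) {
            pos_cnt.erase(it);
        }
    };

    inc(7);
    inc(7);                        // e.g. repeated positions from vision embeddings
    dec(7);
    assert(pos_cnt.count(7) == 1); // still present
    dec(7);
    assert(pos_cnt.count(7) == 0); // gone only after the second release

    return 0;
}
```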