@fugood/llama.node 1.1.9 → 1.1.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/lib/binding.ts +7 -1
  2. package/package.json +14 -14
  3. package/scripts/llama.cpp.patch +15 -5
  4. package/src/LlamaCompletionWorker.cpp +12 -3
  5. package/src/LlamaCompletionWorker.h +3 -1
  6. package/src/LlamaContext.cpp +20 -2
  7. package/src/llama.cpp/common/arg.cpp +29 -19
  8. package/src/llama.cpp/common/chat.cpp +153 -3
  9. package/src/llama.cpp/common/chat.h +1 -0
  10. package/src/llama.cpp/common/common.cpp +10 -3
  11. package/src/llama.cpp/common/common.h +4 -1
  12. package/src/llama.cpp/ggml/CMakeLists.txt +1 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -4
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -1
  15. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +43 -6
  16. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +4 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +14 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  19. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +16 -12
  20. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  21. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  22. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +103 -1
  23. package/src/llama.cpp/include/llama.h +27 -1
  24. package/src/llama.cpp/src/llama-adapter.cpp +68 -4
  25. package/src/llama.cpp/src/llama-adapter.h +3 -0
  26. package/src/llama.cpp/src/llama-arch.cpp +46 -2
  27. package/src/llama.cpp/src/llama-arch.h +4 -0
  28. package/src/llama.cpp/src/llama-context.cpp +80 -39
  29. package/src/llama.cpp/src/llama-context.h +0 -4
  30. package/src/llama.cpp/src/llama-graph.cpp +20 -10
  31. package/src/llama.cpp/src/llama-graph.h +2 -1
  32. package/src/llama.cpp/src/llama-hparams.cpp +25 -0
  33. package/src/llama.cpp/src/llama-hparams.h +6 -0
  34. package/src/llama.cpp/src/llama-impl.h +2 -0
  35. package/src/llama.cpp/src/llama-kv-cache-iswa.cpp +24 -7
  36. package/src/llama.cpp/src/llama-kv-cache-iswa.h +4 -2
  37. package/src/llama.cpp/src/llama-kv-cache.cpp +67 -130
  38. package/src/llama.cpp/src/llama-kv-cache.h +16 -28
  39. package/src/llama.cpp/src/llama-memory-hybrid.cpp +29 -28
  40. package/src/llama.cpp/src/llama-memory-hybrid.h +18 -22
  41. package/src/llama.cpp/src/llama-memory-recurrent.cpp +7 -7
  42. package/src/llama.cpp/src/llama-memory-recurrent.h +7 -11
  43. package/src/llama.cpp/src/llama-memory.h +8 -0
  44. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  45. package/src/llama.cpp/src/llama-model.cpp +302 -31
  46. package/src/llama.cpp/src/llama-model.h +1 -0
  47. package/src/llama.cpp/src/llama-vocab.cpp +1 -1
  48. package/src/llama.cpp/src/llama.cpp +12 -0
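
The most visible API change carried by the vendored llama.cpp bump (see the llama-context.cpp, llama.h and llama.cpp hunks below) is that the boolean context parameter flash_attn is replaced by an enum field flash_attn_type, which defaults to LLAMA_FLASH_ATTN_TYPE_AUTO. Below is a minimal sketch of how a C++ embedder of this llama.cpp revision might adapt; the helper function and the model pointer are illustrative, and only llama_context_default_params(), llama_init_from_model() and the enum values visible in the diff are taken from the source.

    #include "llama.h"

    // Illustrative helper (not part of @fugood/llama.node's JS API): build a
    // context with Flash Attention explicitly disabled, as the removed boolean
    // `params.flash_attn = false` used to do. `model` is assumed to have been
    // loaded elsewhere through the usual llama.h model-loading API.
    static llama_context * make_context_no_fa(llama_model * model) {
        llama_context_params params = llama_context_default_params();
        // new in this release: enum instead of bool, default is LLAMA_FLASH_ATTN_TYPE_AUTO
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
        return llama_init_from_model(model, params);
    }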
package/src/llama.cpp/src/llama-context.cpp
@@ -41,7 +41,6 @@ llama_context::llama_context(
  cparams.yarn_beta_slow = params.yarn_beta_slow;
  cparams.embeddings = params.embeddings;
  cparams.offload_kqv = params.offload_kqv;
- cparams.flash_attn = params.flash_attn;
  cparams.no_perf = params.no_perf;
  cparams.pooling_type = params.pooling_type;
  cparams.warmup = false;
@@ -86,6 +85,8 @@ llama_context::llama_context(
  cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
  }

+ cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
+
  // with causal attention, the batch size is limited by the context size
  cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

@@ -102,16 +103,6 @@ llama_context::llama_context(
  cparams.op_offload = params.op_offload;
  cparams.kv_unified = params.kv_unified;

- {
- const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
- supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
-
- if (!supports_set_rows && !cparams.kv_unified) {
- LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
- cparams.kv_unified = true;
- }
- }
-
  {
  const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
  graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@@ -129,7 +120,7 @@ llama_context::llama_context(
  LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
  LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
  LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
- LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+ LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
  LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
  LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
  LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -279,8 +270,8 @@ llama_context::llama_context(
  }
  }

- // reserve worst-case graph
- if (!hparams.vocab_only && memory) {
+ // resolve automatic Flash Attention use and reserve worst-case graph
+ if (!hparams.vocab_only) {
  const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

@@ -292,11 +283,13 @@ llama_context::llama_context(
  int n_splits_tg = -1;
  int n_nodes_tg = -1;

- // simulate full KV cache
-
- const auto mctx = memory->init_full();
- if (!mctx) {
- throw std::runtime_error("failed to initialize KV cache");
+ llama_memory_context_ptr mctx;
+ if (memory) {
+ LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+ mctx = memory->init_full();
+ if (!mctx) {
+ throw std::runtime_error("failed to initialize memory module");
+ }
  }

  cross.v_embd.clear();
@@ -308,6 +301,48 @@ llama_context::llama_context(
  throw std::runtime_error("failed to allocate compute pp buffers");
  }

+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+ ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+ const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+ bool fa_device_mismatch = false;
+ for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+ ggml_tensor * n = ggml_graph_node(gf, i);
+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+ continue;
+ }
+ ggml_backend_dev_t device_fa = ggml_backend_get_device(
+ ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+ // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+ GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+ const int il = std::stoi(n->name + prefix_len);
+ ggml_backend_dev_t device_kv = model.dev_layer(il);
+ if (device_fa != device_kv) {
+ LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+ "is assigned to device %s (usually due to missing support)\n",
+ __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+ // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+ fa_device_mismatch = true;
+ break;
+ }
+ }
+ if (fa_device_mismatch) {
+ cparams.flash_attn = false;
+ LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+ if (ggml_is_quantized(params.type_v)) {
+ throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+ }
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ if (!gf) {
+ throw std::runtime_error("failed to allocate compute pp buffers");
+ }
+ } else {
+ cparams.flash_attn = true;
+ LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+ }
+ }
+
  n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
  n_nodes_pp = ggml_graph_n_nodes(gf);
  }
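
For reference, the AUTO path above identifies flash-attention nodes by name: build_attn_mha() (llama-graph.cpp hunks below) tags each ggml_flash_attn_ext result with cb(cur, LLAMA_TENSOR_NAME_FATTN, il), the graph callback appends the layer index, and the check then recovers that index with std::stoi. A standalone sketch of just that parsing step, assuming node names of the form "__fattn__-<layer>":

    #include <cassert>
    #include <cstring>
    #include <string>

    // Illustrative only: recover the layer index from a node name such as
    // "__fattn__-17" (LLAMA_TENSOR_NAME_FATTN is "__fattn__", see llama-impl.h below).
    static int fattn_layer_from_name(const char * name) {
        const char * prefix = "__fattn__-";
        const size_t prefix_len = std::strlen(prefix);
        assert(std::strncmp(name, prefix, prefix_len) == 0);
        return std::stoi(name + prefix_len); // "__fattn__-17" -> 17
    }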
@@ -888,12 +923,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
  }
  }

- if (!supports_set_rows) {
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
- // overlap with device computation.
- ggml_backend_sched_reset(sched.get());
- }
-
  // TODO: hacky solution
  if (model.arch == LLM_ARCH_T5 && t_embd) {
  //cross.t_embd = t_embd;
@@ -1056,7 +1085,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
  const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

  if (!res) {
- // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
  llama_pos pos_min[LLAMA_MAX_SEQ];
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
  pos_min[s] = std::numeric_limits<llama_pos>::max();
@@ -1073,7 +1102,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
  continue;
  }

- LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+ LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);

  memory->seq_rm(s, pos_min[s], -1);
  }
@@ -1224,12 +1253,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
  // wait for the computation to finish (automatically done when obtaining the model output)
  //synchronize();

- if (!supports_set_rows) {
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
- // overlap with device computation.
- ggml_backend_sched_reset(sched.get());
- }
-
  return 0;
  }
@@ -1857,7 +1880,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  }

  if (memory != nullptr) {
- LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+ LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
  memory->state_write(io);
  }

@@ -1943,7 +1966,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  }

  if (memory) {
- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+ LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);

  memory->state_read(io);
  }
@@ -2228,6 +2251,7 @@ llama_context_params llama_context_default_params() {
  /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
  /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
  /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+ /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
  /*.rope_freq_base =*/ 0.0f,
  /*.rope_freq_scale =*/ 0.0f,
  /*.yarn_ext_factor =*/ -1.0f,
@@ -2244,7 +2268,6 @@ llama_context_params llama_context_default_params() {
  /*.abort_callback_data =*/ nullptr,
  /*.embeddings =*/ false,
  /*.offload_kqv =*/ true,
- /*.flash_attn =*/ false,
  /*.no_perf =*/ true,
  /*.op_offload =*/ true,
  /*.swa_full =*/ true,
@@ -2272,12 +2295,30 @@ llama_context * llama_init_from_model(
  return nullptr;
  }

- if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+ if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
  LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
- params.flash_attn = false;
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+ }
+
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
+ const uint32_t blck_size = ggml_blck_size(params.type_k);
+ if (model->hparams.n_embd_head_k % blck_size != 0) {
+ LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+ __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
+ return nullptr;
+ }
+ }
+
+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
+ const uint32_t blck_size = ggml_blck_size(params.type_v);
+ if (model->hparams.n_embd_head_v % blck_size != 0) {
+ LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+ __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
+ return nullptr;
+ }
  }

- if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
+ if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
  LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
  return nullptr;
  }
package/src/llama.cpp/src/llama-context.h
@@ -283,10 +283,6 @@ private:

  bool has_evaluated_once = false;

- // env: LLAMA_SET_ROWS (temporary)
- // ref: https://github.com/ggml-org/llama.cpp/pull/14285
- bool supports_set_rows = true;
-
  // env: LLAMA_GRAPH_REUSE_DISABLE
  bool graph_reuse_disable = false;

package/src/llama.cpp/src/llama-graph.cpp
@@ -314,8 +314,6 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
  res &= self_kq_mask->ne[0] == mctx->get_n_kv();
  res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);

- res &= mctx->get_supports_set_rows(); // TODO: tmp
-
  return res;
  }

@@ -350,8 +348,6 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
  res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
  res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);

- res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
-
  return res;
  }
@@ -1225,7 +1221,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
  ggml_tensor * kq_mask,
  ggml_tensor * sinks,
  ggml_tensor * v_mla,
- float kq_scale) const {
+ float kq_scale,
+ int il) const {
  const bool v_trans = v->nb[1] > v->nb[2];

  // split the batch into streams if needed
@@ -1260,6 +1257,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(

  cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+ cb(cur, LLAMA_TENSOR_NAME_FATTN, il);

  ggml_flash_attn_ext_add_sinks(cur, sinks);
  ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
@@ -1275,6 +1273,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
  // The permutations are noops and only change how the tensor data is interpreted.
  cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  cur = ggml_mul_mat(ctx0, v_mla, cur);
+ cb(cur, "fattn_mla", il);
  cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
  #endif
@@ -1283,6 +1282,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
  cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
  } else {
  ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);

  // note: this op tends to require high floating point range
  // while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1296,32 +1296,42 @@ ggml_tensor * llm_graph_context::build_attn_mha(
  // before the softmax below

  kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
+ cb(kq, "kq_tanh", il);
  kq = ggml_scale(ctx0, kq, 30);
+ cb(kq, "kq_scaled", il);
  }

  if (hparams.attn_soft_cap) {
  kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
+ cb(kq, "kq_scaled_1", il);
  kq = ggml_tanh (ctx0, kq);
+ cb(kq, "kq_tanh", il);
  kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
+ cb(kq, "kq_scaled_2", il);
  }

  if (kq_b) {
  kq = ggml_add(ctx0, kq, kq_b);
+ cb(kq, "kq_plus_kq_b", il);
  }

  kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
  ggml_soft_max_add_sinks(kq, sinks);
+ cb(kq, "kq_soft_max", il);

  if (!v_trans) {
  // note: avoid this branch
  v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+ cb(v, "v_cont", il);
  }

  ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+ cb(kqv, "kqv", il);

  // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
  if (v_mla) {
  kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+ cb(kqv, "kqv_mla", il);
  }

  cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
@@ -1376,13 +1386,13 @@ ggml_tensor * llm_graph_context::build_attn(

  // [TAG_NO_CACHE_PAD]
  // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
- assert(!ubatch.equal_seqs());
+ assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));

  ggml_tensor * q = q_cur;
  ggml_tensor * k = k_cur;
  ggml_tensor * v = v_cur;

- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
  cb(cur, "kqv_out", il);

  if (wo) {
@@ -1471,7 +1481,7 @@ ggml_tensor * llm_graph_context::build_attn(
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);

- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
  cb(cur, "kqv_out", il);

  if (wo) {
@@ -1538,7 +1548,7 @@ ggml_tensor * llm_graph_context::build_attn(
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);

- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
  cb(cur, "kqv_out", il);

  if (wo) {
@@ -1593,7 +1603,7 @@ ggml_tensor * llm_graph_context::build_attn(
  ggml_tensor * k = k_cur;
  ggml_tensor * v = v_cur;

- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
  cb(cur, "kqv_out", il);

  if (wo) {
package/src/llama.cpp/src/llama-graph.h
@@ -687,7 +687,8 @@ struct llm_graph_context {
  ggml_tensor * kq_mask,
  ggml_tensor * sinks, // [n_head_q]
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
- float kq_scale) const;
+ float kq_scale,
+ int il) const;

  llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

package/src/llama.cpp/src/llama-hparams.cpp
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {

  GGML_ABORT("fatal error");
  }
+
+ bool llama_hparams::has_kv(uint32_t il) const {
+ if (n_layer_kv_from_start >= 0) {
+ if (il < (uint32_t) n_layer_kv_from_start) {
+ return true;
+ }
+
+ return false;
+ }
+
+ // by default, all layers have kv
+ return true;
+ }
+
+ uint32_t llama_hparams::n_layer_kv() const {
+ uint32_t res = 0;
+
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ if (has_kv(il)) {
+ res++;
+ }
+ }
+
+ return res;
+ }
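
The new helpers above (declared in the llama-hparams.h hunk that follows) let a model mark only its first n_layer_kv_from_start layers as carrying a KV cache. A small self-contained sketch of the semantics; the struct below only mirrors the relevant members and is not the real llama_hparams type:

    #include <cstdint>

    struct hparams_sketch {
        uint32_t n_layer               = 32;
        int32_t  n_layer_kv_from_start = -1; // -1 (default): every layer has KV

        bool has_kv(uint32_t il) const {
            if (n_layer_kv_from_start >= 0) {
                return il < (uint32_t) n_layer_kv_from_start;
            }
            return true;
        }
    };

    // With n_layer_kv_from_start = 4, layers 0..3 have a KV cache and
    // n_layer_kv() would count exactly 4 of the 32 layers.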
package/src/llama.cpp/src/llama-hparams.h
@@ -41,6 +41,7 @@ struct llama_hparams {
  uint32_t n_embd;
  uint32_t n_embd_features = 0;
  uint32_t n_layer;
+ int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
  uint32_t n_rot;
  uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
  uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
  uint32_t n_pos_per_embd() const;

  bool is_swa(uint32_t il) const;
+
+ bool has_kv(uint32_t il) const;
+
+ // number of layers for which has_kv() returns true
+ uint32_t n_layer_kv() const;
  };

  static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
package/src/llama.cpp/src/llama-impl.h
@@ -59,3 +59,5 @@ std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
  std::string llama_format_tensor_shape(const struct ggml_tensor * t);

  std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
+
+ #define LLAMA_TENSOR_NAME_FATTN "__fattn__"
package/src/llama.cpp/src/llama-kv-cache-iswa.cpp
@@ -22,9 +22,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
  uint32_t kv_size,
  uint32_t n_seq_max,
  uint32_t n_ubatch,
- uint32_t n_pad) : hparams(model.hparams), unified(unified) {
- llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
- llama_kv_cache::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
+ uint32_t n_pad,
+ const layer_filter_cb & filter,
+ const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
+
+ // chain filters
+ const layer_filter_cb filter_base = [&](int32_t il) {
+ if (filter && !filter(il)) {
+ return false;
+ }
+
+ return !model.hparams.is_swa(il);
+ };
+
+ const layer_filter_cb filter_swa = [&](int32_t il) {
+ if (filter && !filter(il)) {
+ return false;
+ }
+
+ return model.hparams.is_swa(il);
+ };

  const uint32_t size_base = kv_size;

@@ -41,16 +58,16 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
  LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

  kv_base = std::make_unique<llama_kv_cache>(
- model, std::move(filter_base), type_k, type_v,
+ model, type_k, type_v,
  v_trans, offload, unified, size_base, n_seq_max, n_pad,
- 0, LLAMA_SWA_TYPE_NONE);
+ 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);

  LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

  kv_swa = std::make_unique<llama_kv_cache>(
- model, std::move(filter_swa), type_k, type_v,
+ model, type_k, type_v,
  v_trans, offload, unified, size_swa, n_seq_max, n_pad,
- hparams.n_swa, hparams.swa_type);
+ hparams.n_swa, hparams.swa_type, filter_swa, reuse);
  }

  void llama_kv_cache_iswa::clear(bool data) {
package/src/llama.cpp/src/llama-kv-cache-iswa.h
@@ -20,11 +20,13 @@ public:
  bool v_trans,
  bool offload,
  bool swa_full,
- bool ,
+ bool unified,
  uint32_t kv_size,
  uint32_t n_seq_max,
  uint32_t n_ubatch,
- uint32_t n_pad);
+ uint32_t n_pad,
+ const layer_filter_cb & filter,
+ const layer_reuse_cb & reuse);

  ~llama_kv_cache_iswa() = default;