@fugood/llama.node 1.1.10 → 1.2.0-rc.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. package/CMakeLists.txt +5 -8
  2. package/lib/binding.ts +20 -2
  3. package/lib/index.js +2 -2
  4. package/lib/index.ts +2 -2
  5. package/package.json +20 -16
  6. package/src/DecodeAudioTokenWorker.cpp +23 -26
  7. package/src/DecodeAudioTokenWorker.h +6 -8
  8. package/src/DetokenizeWorker.cpp +5 -8
  9. package/src/DetokenizeWorker.h +6 -5
  10. package/src/DisposeWorker.cpp +23 -3
  11. package/src/DisposeWorker.h +4 -2
  12. package/src/EmbeddingWorker.cpp +9 -35
  13. package/src/EmbeddingWorker.h +3 -2
  14. package/src/LlamaCompletionWorker.cpp +217 -315
  15. package/src/LlamaCompletionWorker.h +6 -12
  16. package/src/LlamaContext.cpp +174 -388
  17. package/src/LlamaContext.h +8 -13
  18. package/src/LoadSessionWorker.cpp +22 -19
  19. package/src/LoadSessionWorker.h +3 -2
  20. package/src/RerankWorker.h +3 -2
  21. package/src/SaveSessionWorker.cpp +22 -19
  22. package/src/SaveSessionWorker.h +3 -2
  23. package/src/TokenizeWorker.cpp +38 -35
  24. package/src/TokenizeWorker.h +12 -3
  25. package/src/common.hpp +0 -458
  26. package/src/llama.cpp/common/arg.cpp +67 -37
  27. package/src/llama.cpp/common/chat.cpp +263 -2
  28. package/src/llama.cpp/common/chat.h +4 -0
  29. package/src/llama.cpp/common/common.cpp +10 -3
  30. package/src/llama.cpp/common/common.h +5 -2
  31. package/src/llama.cpp/common/log.cpp +53 -2
  32. package/src/llama.cpp/common/log.h +10 -4
  33. package/src/llama.cpp/common/sampling.cpp +23 -2
  34. package/src/llama.cpp/common/sampling.h +3 -1
  35. package/src/llama.cpp/common/speculative.cpp +1 -1
  36. package/src/llama.cpp/ggml/CMakeLists.txt +4 -3
  37. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -0
  38. package/src/llama.cpp/ggml/include/ggml-cpu.h +0 -1
  39. package/src/llama.cpp/ggml/include/ggml.h +50 -1
  40. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +19 -16
  41. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
  42. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -7
  43. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +11 -37
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +3 -4
  45. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +43 -6
  46. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +4 -1
  47. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +18 -18
  48. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  49. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +234 -16
  50. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  51. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +80 -51
  52. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +161 -20
  53. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +399 -50
  54. package/src/llama.cpp/include/llama.h +32 -7
  55. package/src/llama.cpp/src/llama-adapter.cpp +101 -4
  56. package/src/llama.cpp/src/llama-adapter.h +6 -0
  57. package/src/llama.cpp/src/llama-arch.cpp +69 -2
  58. package/src/llama.cpp/src/llama-arch.h +6 -0
  59. package/src/llama.cpp/src/llama-context.cpp +92 -45
  60. package/src/llama.cpp/src/llama-context.h +1 -5
  61. package/src/llama.cpp/src/llama-graph.cpp +74 -19
  62. package/src/llama.cpp/src/llama-graph.h +10 -1
  63. package/src/llama.cpp/src/llama-hparams.cpp +37 -0
  64. package/src/llama.cpp/src/llama-hparams.h +9 -3
  65. package/src/llama.cpp/src/llama-impl.h +2 -0
  66. package/src/llama.cpp/src/llama-kv-cache.cpp +33 -120
  67. package/src/llama.cpp/src/llama-kv-cache.h +4 -13
  68. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  69. package/src/llama.cpp/src/llama-model.cpp +434 -21
  70. package/src/llama.cpp/src/llama-model.h +1 -1
  71. package/src/llama.cpp/src/llama-sampling.cpp +226 -126
  72. package/src/llama.cpp/src/llama-vocab.cpp +1 -1
  73. package/src/llama.cpp/src/llama.cpp +12 -0
  74. package/src/anyascii.c +0 -22223
  75. package/src/anyascii.h +0 -42
  76. package/src/tts_utils.cpp +0 -371
  77. package/src/tts_utils.h +0 -103

package/src/llama.cpp/src/llama-adapter.cpp

@@ -6,6 +6,7 @@
 
 #include <map>
 #include <cassert>
+#include <sstream>
 #include <stdexcept>
 
 // vec
@@ -163,13 +164,38 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
 
     // check metadata
     {
+        const gguf_context * gguf_ctx = ctx_gguf.get();
+
+        LLAMA_LOG_INFO("%s: Dumping metadata keys/values.\n", __func__);
+
+        // get metadata as string
+        for (int i = 0; i < gguf_get_n_kv(gguf_ctx); i++) {
+            gguf_type type = gguf_get_kv_type(gguf_ctx, i);
+            const std::string type_name =
+                type == GGUF_TYPE_ARRAY
+                ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(gguf_ctx, i)), gguf_get_arr_n(gguf_ctx, i))
+                : gguf_type_name(type);
+            const char * name = gguf_get_key(gguf_ctx, i);
+            const std::string value = gguf_kv_to_str(gguf_ctx, i);
+
+            if (type != GGUF_TYPE_ARRAY) {
+                adapter.gguf_kv.emplace(name, value);
+            }
+
+            const size_t MAX_VALUE_LEN = 40;
+            std::string print_value = value.size() > MAX_VALUE_LEN ? format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()) : value;
+            replace_all(print_value, "\n", "\\n");
+
+            LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), print_value.c_str());
+        }
+
         auto get_kv_str = [&](const std::string & key) -> std::string {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
+            int id = gguf_find_key(gguf_ctx, key.c_str());
+            return id < 0 ? "" : std::string(gguf_get_val_str(gguf_ctx, id));
         };
         auto get_kv_f32 = [&](const std::string & key) -> float {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
+            int id = gguf_find_key(gguf_ctx, key.c_str());
+            return id < 0 ? 0.0f : gguf_get_val_f32(gguf_ctx, id);
         };
         LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
@@ -190,6 +216,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
         }
 
         adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
+
+        // parse alora invocation sequence vector
+        const auto & key = llm_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS);
+        const int kid = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (kid >= 0) {
+            if (gguf_get_kv_type(ctx_gguf.get(), kid) != GGUF_TYPE_ARRAY) {
+                throw std::runtime_error("invalid gguf type for " + key);
+            }
+            const auto arr_type = gguf_get_arr_type(ctx_gguf.get(), kid);
+            if (arr_type != GGUF_TYPE_UINT32) {
+                throw std::runtime_error("invalid gguf element type for " + key);
+            }
+            const size_t seq_len = gguf_get_arr_n(ctx_gguf.get(), kid);
+            const void * data = gguf_get_arr_data(ctx_gguf.get(), kid);
+            adapter.alora_invocation_tokens.resize(seq_len);
+            std::copy(
+                (const llama_token *)data,
+                (const llama_token *)data + seq_len,
+                adapter.alora_invocation_tokens.begin());
+        }
     }
 
     int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
@@ -383,6 +429,57 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
     return nullptr;
 }
 
+int32_t llama_adapter_meta_val_str(const llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size) {
+    const auto & it = adapter->gguf_kv.find(key);
+    if (it == adapter->gguf_kv.end()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_adapter_meta_count(const llama_adapter_lora * adapter) {
+    return (int)adapter->gguf_kv.size();
+}
+
+int32_t llama_adapter_meta_key_by_index(const llama_adapter_lora * adapter, int i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = adapter->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->first.c_str());
+}
+
+int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = adapter->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
 void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
+
+uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {
+    if (!adapter) {
+        return 0;
+    }
+    return adapter->alora_invocation_tokens.size();
+}
+
+const llama_token * llama_adapter_get_alora_invocation_tokens(const llama_adapter_lora * adapter) {
+    GGML_ASSERT(adapter);
+    return adapter->alora_invocation_tokens.data();
+}
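
For reference, here is a minimal sketch of how the adapter metadata and aLoRA accessors added above might be driven from application code. It assumes these functions are exported through llama.h (the header also changes in this release), and error handling is omitted:

    // Hypothetical helper: enumerate adapter GGUF metadata and aLoRA invocation tokens.
    #include <cstdio>
    #include <vector>
    #include "llama.h"

    void dump_adapter_info(llama_model * model, const char * lora_path) {
        llama_adapter_lora * adapter = llama_adapter_lora_init(model, lora_path);
        if (!adapter) {
            return;
        }

        // walk all key/value pairs copied into adapter->gguf_kv at load time
        char key[128];
        char val[256];
        const int32_t n_kv = llama_adapter_meta_count(adapter);
        for (int32_t i = 0; i < n_kv; i++) {
            llama_adapter_meta_key_by_index(adapter, i, key, sizeof(key));
            llama_adapter_meta_val_str_by_index(adapter, i, val, sizeof(val));
            printf("%s = %s\n", key, val);
        }

        // aLoRA adapters additionally carry the invocation token sequence
        // parsed from adapter.alora.invocation_tokens
        const uint64_t n_inv = llama_adapter_get_alora_n_invocation_tokens(adapter);
        const llama_token * inv = llama_adapter_get_alora_invocation_tokens(adapter);
        std::vector<llama_token> invocation(inv, inv + n_inv);
        printf("alora invocation tokens: %zu\n", invocation.size());

        llama_adapter_lora_free(adapter);
    }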

package/src/llama.cpp/src/llama-adapter.h

@@ -67,6 +67,12 @@ struct llama_adapter_lora {
 
     float alpha;
 
+    // gguf metadata
+    std::unordered_map<std::string, std::string> gguf_kv;
+
+    // activated lora (aLoRA)
+    std::vector<llama_token> alora_invocation_tokens;
+
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
 

package/src/llama.cpp/src/llama-arch.cpp

@@ -22,6 +22,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
     { LLM_ARCH_NEO_BERT, "neo-bert" },
     { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
+    { LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" },
     { LLM_ARCH_BLOOM, "bloom" },
     { LLM_ARCH_STABLELM, "stablelm" },
     { LLM_ARCH_QWEN, "qwen" },
@@ -44,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA2, "gemma2" },
     { LLM_ARCH_GEMMA3, "gemma3" },
     { LLM_ARCH_GEMMA3N, "gemma3n" },
+    { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_MAMBA2, "mamba2" },
@@ -68,6 +70,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_T5ENCODER, "t5encoder" },
     { LLM_ARCH_JAIS, "jais" },
     { LLM_ARCH_NEMOTRON, "nemotron" },
+    { LLM_ARCH_NEMOTRON_H, "nemotron_h" },
     { LLM_ARCH_EXAONE, "exaone" },
     { LLM_ARCH_EXAONE4, "exaone4" },
     { LLM_ARCH_RWKV6, "rwkv6" },
@@ -234,8 +237,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
     { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
 
-    { LLM_KV_ADAPTER_TYPE, "adapter.type" },
-    { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
+    { LLM_KV_ADAPTER_TYPE, "adapter.type" },
+    { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
+    { LLM_KV_ADAPTER_LORA_TASK_NAME, "adapter.lora.task_name" },
+    { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
+    { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" },
 
     // deprecated
     { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
@@ -575,6 +581,20 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_CLS, "cls" },
         },
     },
+    {
+        LLM_ARCH_JINA_BERT_V3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES, "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_BLOOM,
         {
@@ -1020,6 +1040,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA_EMBEDDING,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -1533,6 +1574,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_NEMOTRON_H,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            // mamba(2) ssm layers
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+            // attention layers
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            // dense FFN
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_EXAONE,
         {
@@ -2338,6 +2404,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
        case LLM_ARCH_PLAMO2:
        case LLM_ARCH_GRANITE_HYBRID:
        case LLM_ARCH_LFM2:
+       case LLM_ARCH_NEMOTRON_H:
            return true;
        default:
            return false;

package/src/llama.cpp/src/llama-arch.h

@@ -26,6 +26,7 @@ enum llm_arch {
     LLM_ARCH_NOMIC_BERT_MOE,
     LLM_ARCH_NEO_BERT,
     LLM_ARCH_JINA_BERT_V2,
+    LLM_ARCH_JINA_BERT_V3,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
@@ -48,6 +49,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
     LLM_ARCH_GEMMA3N,
+    LLM_ARCH_GEMMA_EMBEDDING,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_MAMBA2,
@@ -72,6 +74,7 @@ enum llm_arch {
     LLM_ARCH_T5ENCODER,
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
+    LLM_ARCH_NEMOTRON_H,
     LLM_ARCH_EXAONE,
     LLM_ARCH_EXAONE4,
     LLM_ARCH_RWKV6,
@@ -230,6 +233,9 @@ enum llm_kv {
 
     LLM_KV_ADAPTER_TYPE,
     LLM_KV_ADAPTER_LORA_ALPHA,
+    LLM_KV_ADAPTER_LORA_TASK_NAME,
+    LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,
+    LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS,
 
     LLM_KV_POSNET_EMBEDDING_LENGTH,
     LLM_KV_POSNET_BLOCK_COUNT,

package/src/llama.cpp/src/llama-context.cpp

@@ -41,7 +41,6 @@ llama_context::llama_context(
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.embeddings = params.embeddings;
     cparams.offload_kqv = params.offload_kqv;
-    cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
     cparams.pooling_type = params.pooling_type;
     cparams.warmup = false;
@@ -86,6 +85,8 @@
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
+    cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
+
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
@@ -102,16 +103,6 @@
     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
 
-    {
-        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-        supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
-
-        if (!supports_set_rows && !cparams.kv_unified) {
-            LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
-            cparams.kv_unified = true;
-        }
-    }
-
     {
         const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
         graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@@ -129,7 +120,7 @@
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
-    LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
     LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -279,28 +270,75 @@
         }
     }
 
-    // reserve worst-case graph
-    if (!hparams.vocab_only && memory) {
+    if (!hparams.vocab_only) {
+        llama_memory_context_ptr mctx;
+        if (memory) {
+            LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+            mctx = memory->init_full();
+            if (!mctx) {
+                throw std::runtime_error("failed to initialize memory module");
+            }
+        }
+
+        cross.v_embd.clear();
+
         const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
+        // avoid reserving graphs with zero outputs
+        n_outputs = 1;
+
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
+        // resolve automatic Flash Attention use
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to split graph for Flash Attention check");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+            bool fa_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                    continue;
+                }
+                ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_fa != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                                   "is assigned to device %s (usually due to missing support)\n",
+                                   __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                    fa_device_mismatch = true;
+                    break;
+                }
+            }
+            if (fa_device_mismatch) {
+                cparams.flash_attn = false;
+                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+                if (ggml_is_quantized(params.type_v)) {
+                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+                }
+            } else {
+                cparams.flash_attn = true;
+                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+            }
+        }
+
+        // reserve worst-case graph
        int n_splits_pp = -1;
        int n_nodes_pp = -1;

        int n_splits_tg = -1;
        int n_nodes_tg = -1;

-        // simulate full KV cache
-
-        const auto mctx = memory->init_full();
-        if (!mctx) {
-            throw std::runtime_error("failed to initialize KV cache");
-        }
-
-        cross.v_embd.clear();
-
        // reserve pp (prompt processing) graph first so that buffers are only allocated once
        {
            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -888,12 +926,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
        }
    }

-    if (!supports_set_rows) {
-        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-        // overlap with device computation.
-        ggml_backend_sched_reset(sched.get());
-    }
-
    // TODO: hacky solution
    if (model.arch == LLM_ARCH_T5 && t_embd) {
        //cross.t_embd = t_embd;
@@ -1056,7 +1088,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

        if (!res) {
-            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
            llama_pos pos_min[LLAMA_MAX_SEQ];
            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                pos_min[s] = std::numeric_limits<llama_pos>::max();
@@ -1073,7 +1105,7 @@
                    continue;
                }

-                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+                LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);

                memory->seq_rm(s, pos_min[s], -1);
            }
@@ -1224,12 +1256,6 @@
    // wait for the computation to finish (automatically done when obtaining the model output)
    //synchronize();

-    if (!supports_set_rows) {
-        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-        // overlap with device computation.
-        ggml_backend_sched_reset(sched.get());
-    }
-
    return 0;
}

@@ -1343,8 +1369,9 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
    return static_cast<llm_graph_result *>(gf_res_reserve.get());
}

-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
    LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+    GGML_ASSERT(n_outputs >= 1);

    if (n_tokens % n_seqs != 0) {
        n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
@@ -1378,7 +1405,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
    this->n_outputs = save_n_outputs;

    // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+    if (split_only) {
+        ggml_backend_sched_split_graph(sched.get(), gf);
+    } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        return nullptr;
    }
@@ -1857,7 +1886,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
    }

    if (memory != nullptr) {
-        LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+        LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
        memory->state_write(io);
    }

@@ -1943,7 +1972,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
    }

    if (memory) {
-        LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+        LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);

        memory->state_read(io);
    }
@@ -2228,6 +2257,7 @@ llama_context_params llama_context_default_params() {
        /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
        /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
        /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+        /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
        /*.rope_freq_base =*/ 0.0f,
        /*.rope_freq_scale =*/ 0.0f,
        /*.yarn_ext_factor =*/ -1.0f,
@@ -2244,7 +2274,6 @@ llama_context_params llama_context_default_params() {
        /*.abort_callback_data =*/ nullptr,
        /*.embeddings =*/ false,
        /*.offload_kqv =*/ true,
-        /*.flash_attn =*/ false,
        /*.no_perf =*/ true,
        /*.op_offload =*/ true,
        /*.swa_full =*/ true,
@@ -2272,12 +2301,30 @@ llama_context * llama_init_from_model(
        return nullptr;
    }

-    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+    if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        params.flash_attn = false;
+        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+    }
+
+    if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
+        const uint32_t blck_size = ggml_blck_size(params.type_k);
+        if (model->hparams.n_embd_head_k % blck_size != 0) {
+            LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+                            __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
+            return nullptr;
+        }
+    }
+
+    if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
+        const uint32_t blck_size = ggml_blck_size(params.type_v);
+        if (model->hparams.n_embd_head_v % blck_size != 0) {
+            LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+                            __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
+            return nullptr;
+        }
    }

-    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
+    if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
        return nullptr;
    }
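
In this release the boolean flash_attn context parameter is replaced by a three-state flash_attn_type that defaults to LLAMA_FLASH_ATTN_TYPE_AUTO and is resolved while reserving the worst-case graph, as shown above. A minimal sketch of how a caller might pin the behaviour instead of relying on auto-detection (the helper function below is hypothetical; the enum values, field name, and API calls are taken from this diff):

    #include "llama.h"

    // Hypothetical helper: build a context with Flash Attention explicitly disabled.
    llama_context * make_context(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();

        // default is LLAMA_FLASH_ATTN_TYPE_AUTO: the context constructor probes the
        // split graph and falls back if a Flash Attention node lands on a different
        // device than its layer's KV cache
        cparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;

        return llama_init_from_model(model, cparams);
    }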

package/src/llama.cpp/src/llama-context.h

@@ -196,7 +196,7 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
 
 private:
     llm_graph_params graph_params(
@@ -283,10 +283,6 @@ private:
 
     bool has_evaluated_once = false;
 
-    // env: LLAMA_SET_ROWS (temporary)
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
-    bool supports_set_rows = true;
-
     // env: LLAMA_GRAPH_REUSE_DISABLE
     bool graph_reuse_disable = false;