@fugood/llama.node 1.1.4 → 1.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. package/lib/binding.ts +8 -0
  2. package/package.json +14 -14
  3. package/src/LlamaContext.cpp +3 -0
  4. package/src/llama.cpp/common/arg.cpp +60 -7
  5. package/src/llama.cpp/common/chat.cpp +6 -6
  6. package/src/llama.cpp/common/common.cpp +1 -0
  7. package/src/llama.cpp/common/common.h +14 -5
  8. package/src/llama.cpp/common/speculative.cpp +135 -54
  9. package/src/llama.cpp/common/speculative.h +8 -1
  10. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  11. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +3196 -0
  12. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +14 -0
  13. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +263 -0
  14. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +11 -0
  15. package/src/llama.cpp/include/llama.h +8 -4
  16. package/src/llama.cpp/src/llama-arch.cpp +40 -0
  17. package/src/llama.cpp/src/llama-arch.h +2 -0
  18. package/src/llama.cpp/src/llama-batch.cpp +1 -1
  19. package/src/llama.cpp/src/llama-chat.cpp +20 -1
  20. package/src/llama.cpp/src/llama-chat.h +1 -0
  21. package/src/llama.cpp/src/llama-context.cpp +11 -2
  22. package/src/llama.cpp/src/llama-context.h +4 -1
  23. package/src/llama.cpp/src/llama-graph.cpp +57 -139
  24. package/src/llama.cpp/src/llama-graph.h +31 -32
  25. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +2 -2
  26. package/src/llama.cpp/src/llama-kv-cache-unified.h +1 -1
  27. package/src/llama.cpp/src/llama-memory-hybrid.cpp +2 -1
  28. package/src/llama.cpp/src/llama-memory-hybrid.h +1 -0
  29. package/src/llama.cpp/src/llama-model.cpp +400 -21
  30. package/src/llama.cpp/src/llama-quant.cpp +3 -3
  31. package/src/llama.cpp/src/llama-vocab.cpp +7 -1
  32. package/src/llama.cpp/src/llama-vocab.h +1 -0
package/lib/binding.ts CHANGED
@@ -65,6 +65,14 @@ export type LlamaModelOptions = {
  lora?: string
  lora_scaled?: number
  lora_list?: { path: string; scaled: number }[]
+ /**
+ * RoPE base frequency, use 0 to use model default (recommended)
+ */
+ rope_freq_base?: number
+ /**
+ * RoPE frequency scaling factor, use 0 to use model default (recommended)
+ */
+ rope_freq_scale?: number
  }

  export type CompletionResponseFormat = {
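
Usage note: the two new fields are plain optional numbers on LlamaModelOptions and default to 0, which keeps the model's own RoPE settings. A minimal TypeScript sketch, assuming the package's loadModel entry point and a placeholder model path (only rope_freq_base and rope_freq_scale come from this release):

// Hypothetical sketch: passing the RoPE overrides added in 1.1.5.
// `loadModel` and the `model` field are assumptions; adjust to the actual API.
import { loadModel } from '@fugood/llama.node'

const context = await loadModel({
  model: './models/some-model.gguf', // placeholder path
  rope_freq_base: 0,  // 0 = use the model default (recommended)
  rope_freq_scale: 0, // 0 = use the model default (recommended)
})
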
package/package.json CHANGED
@@ -1,7 +1,7 @@
  {
  "name": "@fugood/llama.node",
  "access": "public",
- "version": "1.1.4",
+ "version": "1.1.5",
  "description": "An another Node binding of llama.cpp",
  "main": "lib/index.js",
  "scripts": {
@@ -71,19 +71,19 @@
  "CMakeLists.txt"
  ],
  "optionalDependencies": {
- "@fugood/node-llama-linux-x64": "1.1.4",
- "@fugood/node-llama-linux-x64-vulkan": "1.1.4",
- "@fugood/node-llama-linux-x64-cuda": "1.1.4",
- "@fugood/node-llama-linux-arm64": "1.1.4",
- "@fugood/node-llama-linux-arm64-vulkan": "1.1.4",
- "@fugood/node-llama-linux-arm64-cuda": "1.1.4",
- "@fugood/node-llama-win32-x64": "1.1.4",
- "@fugood/node-llama-win32-x64-vulkan": "1.1.4",
- "@fugood/node-llama-win32-x64-cuda": "1.1.4",
- "@fugood/node-llama-win32-arm64": "1.1.4",
- "@fugood/node-llama-win32-arm64-vulkan": "1.1.4",
- "@fugood/node-llama-darwin-x64": "1.1.4",
- "@fugood/node-llama-darwin-arm64": "1.1.4"
+ "@fugood/node-llama-linux-x64": "1.1.5",
+ "@fugood/node-llama-linux-x64-vulkan": "1.1.5",
+ "@fugood/node-llama-linux-x64-cuda": "1.1.5",
+ "@fugood/node-llama-linux-arm64": "1.1.5",
+ "@fugood/node-llama-linux-arm64-vulkan": "1.1.5",
+ "@fugood/node-llama-linux-arm64-cuda": "1.1.5",
+ "@fugood/node-llama-win32-x64": "1.1.5",
+ "@fugood/node-llama-win32-x64-vulkan": "1.1.5",
+ "@fugood/node-llama-win32-x64-cuda": "1.1.5",
+ "@fugood/node-llama-win32-arm64": "1.1.5",
+ "@fugood/node-llama-win32-arm64-vulkan": "1.1.5",
+ "@fugood/node-llama-darwin-x64": "1.1.5",
+ "@fugood/node-llama-darwin-arm64": "1.1.5"
  },
  "devDependencies": {
  "@babel/preset-env": "^7.24.4",
package/src/LlamaContext.cpp CHANGED
@@ -250,6 +250,9 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  params.kv_unified = get_option<bool>(options, "kv_unified", false);
  params.swa_full = get_option<bool>(options, "swa_full", false);

+ params.rope_freq_base = get_option<float>(options, "rope_freq_base", 0.0f);
+ params.rope_freq_scale = get_option<float>(options, "rope_freq_scale", 0.0f);
+
  params.use_mlock = get_option<bool>(options, "use_mlock", false);
  params.use_mmap = get_option<bool>(options, "use_mmap", true);
  params.numa =
package/src/llama.cpp/common/arg.cpp CHANGED
@@ -977,6 +977,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
  for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
  string_process_escapes(seq_breaker);
  }
+ for (auto & pair : params.speculative.replacements) {
+ string_process_escapes(pair.first);
+ string_process_escapes(pair.second);
+ }
  }

  if (!params.kv_overrides.empty()) {
@@ -2091,6 +2095,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
  params.no_kv_offload = true;
  }
  ).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
+ add_opt(common_arg(
+ {"-nr", "--no-repack"},
+ "disable weight repacking",
+ [](common_params & params) {
+ params.no_extra_bufts = true;
+ }
+ ).set_env("LLAMA_ARG_NO_REPACK"));
  add_opt(common_arg(
  {"-ctk", "--cache-type-k"}, "TYPE",
  string_format(
@@ -2369,6 +2380,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
  }
  }
  ));
+ add_opt(common_arg(
+ {"--cpu-moe"},
+ "use CPU for Mixture of Experts (MoE) weights",
+ [](common_params & params) {
+ params.tensor_buft_overrides.push_back({"\\.ffn_up_exps\\.weight$", ggml_backend_cpu_buffer_type()});
+ params.tensor_buft_overrides.push_back({"\\.ffn_down_exps\\.weight$", ggml_backend_cpu_buffer_type()});
+ params.tensor_buft_overrides.push_back({"\\.ffn_gate_exps\\.weight$", ggml_backend_cpu_buffer_type()});
+ }
+ ).set_env("LLAMA_ARG_CPU_MOE"));
  add_opt(common_arg(
  {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
  "number of layers to store in VRAM",
@@ -2627,6 +2647,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
  params.n_out_freq = value;
  }
  ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
+ add_opt(common_arg(
+ {"--output-format"}, "{gguf,dat}",
+ string_format("output format for imatrix file (default: %s)", params.imat_dat ? "dat" : "gguf"),
+ [](common_params & params, const std::string & value) {
+ /**/ if (value == "gguf") { params.imat_dat = false; }
+ else if (value == "dat") { params.imat_dat = true; }
+ else { throw std::invalid_argument("invalid output format"); }
+ }
+ ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
  add_opt(common_arg(
  {"--save-frequency"}, "N",
  string_format("save an imatrix copy every N iterations (default: %d)", params.n_save_freq),
@@ -3249,6 +3278,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
  params.speculative.model.path = value;
  }
  ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
+ add_opt(common_arg(
+ {"--spec-replace"}, "TARGET", "DRAFT",
+ "translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
+ [](common_params & params, const std::string & tgt, const std::string & dft) {
+ params.speculative.replacements.push_back({ tgt, dft });
+ }
+ ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
  add_opt(common_arg(
  {"-ctkd", "--cache-type-k-draft"}, "TYPE",
  string_format(
@@ -3438,12 +3474,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
  }
  ).set_examples({LLAMA_EXAMPLE_SERVER}));

- // diffusion parameters
  add_opt(common_arg(
  { "--diffusion-steps" }, "N",
  string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
  [](common_params & params, int value) { params.diffusion.steps = value; }
  ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+ add_opt(common_arg(
+ { "--diffusion-visual" },
+ string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
+ params.diffusion.visual_mode ? "true" : "false"),
+ [](common_params & params) { params.diffusion.visual_mode = true; }
+ ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+
  add_opt(common_arg(
  { "--diffusion-eps" }, "F",
  string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
@@ -3451,21 +3493,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
  ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
  add_opt(common_arg(
  { "--diffusion-algorithm" }, "N",
- string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
+ string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)",
  params.diffusion.algorithm),
  [](common_params & params, int value) { params.diffusion.algorithm = value; }
  ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
  add_opt(common_arg(
  { "--diffusion-alg-temp" }, "F",
- string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
+ string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
  [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
  ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+
  add_opt(common_arg(
- { "--diffusion-visual" },
- string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
- params.diffusion.visual_mode ? "true" : "false"),
- [](common_params & params) { params.diffusion.visual_mode = true; }
+ { "--diffusion-block-length" }, "N",
+ string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
+ [](common_params & params, int value) { params.diffusion.block_length = value; }
+ ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+ add_opt(common_arg(
+ { "--diffusion-cfg-scale" }, "F",
+ string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
+ [](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
  ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+ add_opt(common_arg(
+ { "--diffusion-add-gumbel-noise" }, "F",
+ string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
+ [](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
+ ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+

  return ctx_arg;
  }
package/src/llama.cpp/common/chat.cpp CHANGED
@@ -1635,7 +1635,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
  "|<function name=\"([^\"]+)\">" // match 5 (function name again)
  );

- if (auto res = builder.try_find_regex(open_regex)) {
+ while (auto res = builder.try_find_regex(open_regex)) {
  const auto & block_start = res->groups[1];
  std::string block_end = block_start.empty() ? "" : "```";

@@ -1657,7 +1657,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
  builder.consume_literal(block_end);
  builder.consume_spaces();
  }
- builder.add_content(builder.consume_rest());
  } else {
  throw common_chat_msg_partial_exception("failed to parse tool call");
  }
@@ -1682,11 +1681,10 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
  builder.consume_spaces();
  }
  }
- builder.add_content(builder.consume_rest());
  }
- } else {
- builder.add_content(builder.consume_rest());
  }
+
+ builder.add_content(builder.consume_rest());
  }

  static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1933,6 +1931,8 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
  }
  }
  auto msg = builder.result();
- LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
+ if (!is_partial) {
+ LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
+ }
  return msg;
  }
package/src/llama.cpp/common/common.cpp CHANGED
@@ -1123,6 +1123,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
  mparams.use_mmap = params.use_mmap;
  mparams.use_mlock = params.use_mlock;
  mparams.check_tensors = params.check_tensors;
+ mparams.use_extra_bufts = !params.no_extra_bufts;

  if (params.kv_overrides.empty()) {
  mparams.kv_overrides = NULL;
package/src/llama.cpp/common/common.h CHANGED
@@ -201,6 +201,7 @@ struct common_params_speculative {
  int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
  float p_split = 0.1f; // speculative decoding split probability
  float p_min = 0.75f; // minimum speculative decoding probability (greedy)
+ std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements

  ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
  ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
@@ -220,11 +221,17 @@ struct common_params_vocoder {
  };

  struct common_params_diffusion {
- int32_t steps = 64; // number of diffusion steps
- float eps = 1e-3f; // epsilon for timesteps
- int32_t algorithm = 0; // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
- float alg_temp = 0.0f; // algorithm temperature
- bool visual_mode = false; // show progressive diffusion on screen
+ int32_t steps = 128;
+ bool visual_mode = false;
+
+ float eps = 0; // epsilon for timesteps
+ int32_t block_length = 0; // block length for generation
+
+ int32_t algorithm = 4; // default algorithm: low-confidence
+ float alg_temp = 0.0f; // algorithm temperature
+
+ float cfg_scale = 0; // classifier-free guidance scale
+ bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
  };

  enum common_reasoning_format {
@@ -353,6 +360,7 @@ struct common_params {
  bool warmup = true; // warmup run
  bool check_tensors = false; // validate tensor data
  bool no_op_offload = false; // globally disable offload host tensor operations to device
+ bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)

  bool single_turn = false; // single turn chat conversation

@@ -432,6 +440,7 @@ struct common_params {
  int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
  int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
  int32_t i_chunk = 0; // start processing from this chunk
+ bool imat_dat = false; // whether the legacy imatrix.dat format should be output

  bool process_output = false; // collect data for the output tensor
  bool compute_ppl = true; // whether to compute perplexity
package/src/llama.cpp/common/speculative.cpp CHANGED
@@ -1,30 +1,39 @@
  #include "speculative.h"

+ #include "ggml.h"
+ #include "llama.h"
  #include "log.h"
  #include "common.h"
  #include "sampling.h"

  #include <cstring>
  #include <algorithm>
+ #include <map>

  #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
  #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

  struct common_speculative {
- struct llama_context * ctx;
+ struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
+ struct llama_context * ctx_dft;
  struct common_sampler * smpl;

  llama_batch batch;
- llama_tokens prompt;
+ llama_tokens prompt_dft;
+ bool vocab_dft_compatible = true; // whether retokenization is needed
+ std::map<std::string, std::string> tgt_dft_replacements = {};
  };

  struct common_speculative * common_speculative_init(
+ struct llama_context * ctx_tgt,
  struct llama_context * ctx_dft) {
  auto * result = new common_speculative {
- /* .ctx = */ ctx_dft,
- /* .smpl = */ nullptr,
- /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
- /* .prompt = */ {},
+ /* .ctx_tgt = */ ctx_tgt,
+ /* .ctx_dft = */ ctx_dft,
+ /* .smpl = */ nullptr,
+ /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+ /* .prompt_dft = */ {},
+ /* .vocab_dft_compatible = */ false,
  };

  // TODO: optimize or pass from outside?
@@ -59,6 +68,9 @@ struct common_speculative * common_speculative_init(
  }
  #endif

+ result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
+ LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
+
  return result;
  }

@@ -75,8 +87,8 @@ void common_speculative_free(struct common_speculative * spec) {
  }

  bool common_speculative_are_compatible(
- const struct llama_context * ctx_tgt,
- const struct llama_context * ctx_dft) {
+ const struct llama_context * ctx_tgt,
+ const struct llama_context * ctx_dft) {
  const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
  const struct llama_model * model_dft = llama_get_model(ctx_dft);

@@ -90,31 +102,32 @@ bool common_speculative_are_compatible(
  LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);

  if (vocab_type_tgt != vocab_type_dft) {
- LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
- "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+ LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
+ LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
  return false;
  }

- if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+ if (
+ llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
  llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
  llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
- llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
- LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
- LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
- LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
+ llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
+ ) {
+ LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
  return false;
  }

  {
  const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
  const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
-
- const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+ const int vocab_diff = n_vocab_tgt > n_vocab_dft
+ ? n_vocab_tgt - n_vocab_dft
+ : n_vocab_dft - n_vocab_tgt;

  if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
- LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
- "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
- __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+ LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
+ LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+ n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
  return false;
  }

@@ -122,8 +135,8 @@ bool common_speculative_are_compatible(
  const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
  const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
  if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
- LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
- "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+ LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
+ LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
  common_token_to_piece(ctx_tgt, i).c_str(),
  common_token_to_piece(ctx_dft, i).c_str());
  return false;
@@ -134,32 +147,93 @@ bool common_speculative_are_compatible(
  return true;
  }

+ void common_speculative_add_replacement_tgt_dft(
+ struct common_speculative * spec,
+ const char *source, const char *dest) {
+ spec->tgt_dft_replacements[source] = dest;
+ }
+
+ static std::string replace_to_dft(
+ struct common_speculative * spec,
+ const std::string& input) {
+ std::string result = input;
+ for (const auto & pair : spec->tgt_dft_replacements) {
+ size_t pos = result.find(pair.first);
+ while (pos != std::string::npos) {
+ result.replace(pos, pair.first.length(), pair.second);
+ pos = result.find(pair.first, pos + pair.second.length());
+ }
+ }
+ return result;
+ }
+
+ static std::string replace_to_tgt(
+ struct common_speculative * spec,
+ const std::string& input) {
+ std::string result = input;
+ for (const auto& pair : spec->tgt_dft_replacements) {
+ size_t pos = result.find(pair.second);
+ while (pos != std::string::npos) {
+ result.replace(pos, pair.second.length(), pair.first);
+ pos = result.find(pair.second, pos + pair.first.length());
+ }
+ }
+ return result;
+ }
+
+
  llama_tokens common_speculative_gen_draft(
  struct common_speculative * spec,
  struct common_speculative_params params,
- const llama_tokens & prompt_tgt,
+ const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
  llama_token id_last) {
  auto & batch = spec->batch;
- auto & ctx = spec->ctx;
+ auto & ctx_tgt = spec->ctx_tgt;
+ auto & ctx_dft = spec->ctx_dft;
  auto & smpl = spec->smpl;
- auto & prompt = spec->prompt;
+ auto & prompt_dft = spec->prompt_dft;

- auto * mem = llama_get_memory(ctx);
+ auto * mem_dft = llama_get_memory(ctx_dft);

  int reuse_i = 0;
  int reuse_n = 0;

- const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+ const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
+
+ llama_tokens prompt_tgt_draft_model;
+ if (!spec->vocab_dft_compatible) {
+ std::string text;
+ text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
+ text = replace_to_dft(spec, text);
+ LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
+ prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
+
+ // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
+ const auto * model_tgt = llama_get_model(ctx_tgt);
+ const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
+
+ int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
+ GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
+ text.resize(-n_chars);
+ llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
+ text = replace_to_dft(spec, text);
+
+ LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
+ id_last = common_tokenize(ctx_dft, text, false, true)[0];
+ }
+ // prompt_tgt's tokens will always be compatible with ctx_dft
+ const llama_tokens &prompt_tgt =
+ spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;

  const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);

  // reuse as much as possible from the old draft context
  // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
- for (int i = 0; i < (int) prompt.size(); ++i) {
+ for (int i = 0; i < (int) prompt_dft.size(); ++i) {
  int cur = 0;
  while (i_start + cur < (int) prompt_tgt.size() &&
- i + cur < (int) prompt.size() &&
- prompt_tgt[i_start + cur] == prompt[i + cur]) {
+ i + cur < (int) prompt_dft.size() &&
+ prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
  cur++;
  }

@@ -169,21 +243,20 @@ llama_tokens common_speculative_gen_draft(
  }
  }

- LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+ LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());

  llama_tokens result;
  result.reserve(params.n_draft);

  if (reuse_n == 0) {
- llama_memory_clear(mem, false);
-
- prompt.clear();
+ llama_memory_clear(mem_dft, false);
+ prompt_dft.clear();
  } else {
  // this happens when a previous draft has been discarded (for example, due to being too small), but the
  // target model agreed with it. in this case, we simply pass back the previous results to save compute
- if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
- for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
- result.push_back(prompt[i]);
+ if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
+ for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
+ result.push_back(prompt_dft[i]);

  if (params.n_draft <= (int) result.size()) {
  break;
@@ -194,16 +267,15 @@ llama_tokens common_speculative_gen_draft(
  }

  if (reuse_i > 0) {
- llama_memory_seq_rm (mem, 0, 0, reuse_i);
- llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
+ llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
+ llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);

- prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+ prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
  }

- if (reuse_n < (int) prompt.size()) {
- llama_memory_seq_rm (mem, 0, reuse_n, -1);
-
- prompt.erase(prompt.begin() + reuse_n, prompt.end());
+ if (reuse_n < (int) prompt_dft.size()) {
+ llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
+ prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
  }
  }

@@ -214,28 +286,28 @@ llama_tokens common_speculative_gen_draft(
  //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
  common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);

- prompt.push_back(prompt_tgt[i]);
+ prompt_dft.push_back(prompt_tgt[i]);
  }

  // we should rarely end-up here during normal decoding
  if (batch.n_tokens > 0) {
  //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

- llama_decode(ctx, batch);
+ llama_decode(ctx_dft, batch);
  }

- const llama_pos n_past = prompt.size();
+ const llama_pos n_past = prompt_dft.size();

  LOG_DBG("%s: n_past = %d\n", __func__, n_past);

  common_batch_clear(batch);
  common_batch_add (batch, id_last, n_past, { 0 }, true);

- prompt.push_back(id_last);
+ prompt_dft.push_back(id_last);

- //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+ LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());

- llama_decode(ctx, batch);
+ llama_decode(ctx_dft, batch);

  common_sampler_reset(smpl);

@@ -243,13 +315,13 @@ llama_tokens common_speculative_gen_draft(
  for (int i = 0; i < params.n_draft; ++i) {
  common_batch_clear(batch);

- common_sampler_sample(smpl, ctx, 0, true);
+ common_sampler_sample(smpl, ctx_dft, 0, true);

  const auto * cur_p = common_sampler_get_candidates(smpl);

  for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
  LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
- k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+ k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
  }

  // add drafted token for each sequence
@@ -271,10 +343,19 @@ llama_tokens common_speculative_gen_draft(
  common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

  // evaluate the drafted tokens on the draft model
- llama_decode(ctx, batch);
+ llama_decode(ctx_dft, batch);

- prompt.push_back(id);
+ prompt_dft.push_back(id);
  }

+ if (!spec->vocab_dft_compatible) {
+ std::string detokenized = common_detokenize(ctx_dft, result, true);
+ detokenized = replace_to_tgt(spec, detokenized);
+ LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
+ result = common_tokenize(ctx_tgt, detokenized, false, true);
+ if (result.size() > (size_t)params.n_draft) {
+ result.resize(params.n_draft);
+ }
+ }
  return result;
  }