@fugood/llama.node 0.3.8 → 0.3.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.js +2 -2
  18. package/lib/binding.ts +52 -8
  19. package/lib/index.ts +3 -1
  20. package/package.json +8 -1
  21. package/src/LlamaCompletionWorker.cpp +33 -6
  22. package/src/LlamaCompletionWorker.h +3 -1
  23. package/src/LlamaContext.cpp +387 -28
  24. package/src/LlamaContext.h +5 -0
  25. package/src/common.hpp +19 -2
  26. package/src/llama.cpp/.github/workflows/build.yml +289 -107
  27. package/src/llama.cpp/.github/workflows/close-issue.yml +1 -1
  28. package/src/llama.cpp/.github/workflows/docker.yml +2 -1
  29. package/src/llama.cpp/.github/workflows/server.yml +25 -2
  30. package/src/llama.cpp/CMakeLists.txt +10 -19
  31. package/src/llama.cpp/cmake/build-info.cmake +1 -1
  32. package/src/llama.cpp/common/CMakeLists.txt +32 -0
  33. package/src/llama.cpp/common/arg.cpp +66 -16
  34. package/src/llama.cpp/common/chat-template.hpp +515 -0
  35. package/src/llama.cpp/common/chat.cpp +966 -0
  36. package/src/llama.cpp/common/chat.hpp +52 -0
  37. package/src/llama.cpp/common/common.cpp +159 -36
  38. package/src/llama.cpp/common/common.h +56 -14
  39. package/src/llama.cpp/common/json-schema-to-grammar.cpp +46 -66
  40. package/src/llama.cpp/common/json-schema-to-grammar.h +15 -1
  41. package/src/llama.cpp/common/llguidance.cpp +270 -0
  42. package/src/llama.cpp/common/log.cpp +1 -10
  43. package/src/llama.cpp/common/log.h +10 -0
  44. package/src/llama.cpp/common/minja.hpp +2868 -0
  45. package/src/llama.cpp/common/sampling.cpp +22 -1
  46. package/src/llama.cpp/common/sampling.h +3 -0
  47. package/src/llama.cpp/docs/build.md +54 -9
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +12 -2
  49. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +1 -1
  50. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  51. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +59 -0
  52. package/src/llama.cpp/examples/llava/clip.cpp +133 -14
  53. package/src/llama.cpp/examples/llava/clip.h +2 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +22 -8
  55. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +9 -1
  56. package/src/llama.cpp/examples/main/main.cpp +26 -25
  57. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +136 -137
  58. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +18 -4
  59. package/src/llama.cpp/examples/run/run.cpp +224 -69
  60. package/src/llama.cpp/examples/server/server.cpp +252 -81
  61. package/src/llama.cpp/examples/server/utils.hpp +73 -21
  62. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +6 -4
  63. package/src/llama.cpp/examples/simple-cmake-pkg/CMakeLists.txt +11 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +78 -1
  65. package/src/llama.cpp/ggml/include/ggml.h +1 -1
  66. package/src/llama.cpp/ggml/src/CMakeLists.txt +21 -4
  67. package/src/llama.cpp/ggml/src/ggml-alloc.c +1 -13
  68. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +91 -78
  69. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +7 -7
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -1
  71. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +46 -0
  73. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +16 -1
  74. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +28 -8
  76. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +5 -7
  77. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +33 -23
  78. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +1 -5
  79. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +323 -121
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
  81. package/src/llama.cpp/ggml/src/ggml.c +23 -13
  82. package/src/llama.cpp/include/llama.h +14 -1
  83. package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +112 -0
  84. package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +46 -0
  85. package/src/llama.cpp/src/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/src/llama-arch.cpp +7 -2
  87. package/src/llama.cpp/src/llama-arch.h +3 -1
  88. package/src/llama.cpp/src/llama-chat.cpp +11 -2
  89. package/src/llama.cpp/src/llama-chat.h +1 -0
  90. package/src/llama.cpp/src/llama-grammar.cpp +86 -6
  91. package/src/llama.cpp/src/llama-grammar.h +22 -1
  92. package/src/llama.cpp/src/llama-mmap.cpp +1 -0
  93. package/src/llama.cpp/src/llama-model-loader.cpp +1 -1
  94. package/src/llama.cpp/src/llama-model.cpp +76 -6
  95. package/src/llama.cpp/src/llama-sampling.cpp +47 -4
  96. package/src/llama.cpp/src/llama-vocab.cpp +10 -4
  97. package/src/llama.cpp/src/llama.cpp +181 -123
  98. package/src/llama.cpp/tests/CMakeLists.txt +4 -0
  99. package/src/llama.cpp/tests/test-backend-ops.cpp +158 -57
  100. package/src/llama.cpp/tests/test-chat-template.cpp +154 -31
  101. package/src/llama.cpp/tests/test-chat.cpp +607 -0
  102. package/src/llama.cpp/tests/test-grammar-integration.cpp +2 -2
  103. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +1140 -0
  104. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +1 -1
  105. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +0 -32
package/src/llama.cpp/examples/server/server.cpp
@@ -14,7 +14,7 @@
  // mime type for sending response
  #define MIMETYPE_JSON "application/json; charset=utf-8"

- // auto generated files (update with ./deps.sh)
+ // auto generated files (see README.md for details)
  #include "index.html.gz.hpp"
  #include "loading.html.hpp"

@@ -113,10 +113,11 @@ struct slot_params {
  struct common_params_speculative speculative;

  // OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

  json to_json() const {
  std::vector<std::string> samplers;
@@ -130,6 +131,11 @@ struct slot_params {
  lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
  }

+ std::vector<std::string> grammar_trigger_words;
+ for (const auto & trigger : sampling.grammar_trigger_words) {
+ grammar_trigger_words.push_back(trigger.word);
+ }
+
  return json {
  {"n_predict", n_predict}, // Server configured n_predict
  {"seed", sampling.seed},
@@ -164,6 +170,9 @@ struct slot_params {
  {"n_probs", sampling.n_probs},
  {"min_keep", sampling.min_keep},
  {"grammar", sampling.grammar},
+ {"grammar_trigger_words", grammar_trigger_words},
+ {"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
+ {"preserved_tokens", sampling.preserved_tokens},
  {"samplers", samplers},
  {"speculative.n_max", speculative.n_max},
  {"speculative.n_min", speculative.n_min},
@@ -267,6 +276,11 @@ struct server_task {
  params.speculative.n_min = std::max(params.speculative.n_min, 2);
  params.speculative.n_max = std::max(params.speculative.n_max, 0);

+ // Use OpenAI API logprobs only if n_probs wasn't provided
+ if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
+ params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
+ }
+
  if (data.contains("lora")) {
  if (data.at("lora").is_array()) {
  params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
@@ -320,12 +334,64 @@ struct server_task {
  if (data.contains("json_schema") && !data.contains("grammar")) {
  try {
  auto schema = json_value(data, "json_schema", json::object());
- params.sampling.grammar = json_schema_to_grammar(schema);
+ LOG_DBG("JSON schema: %s\n", schema.dump(2).c_str());
+ params.sampling.grammar = json_schema_to_grammar(schema);
+ LOG_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
  } catch (const std::exception & e) {
  throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
  }
  } else {
- params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+ params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+ LOG_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
+ params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
+ LOG_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
+ }
+
+ {
+ auto it = data.find("chat_format");
+ if (it != data.end()) {
+ params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
+ LOG_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
+ } else {
+ params.oaicompat_chat_format = defaults.oaicompat_chat_format;
+ }
+ }
+
+ {
+ const auto grammar_triggers = data.find("grammar_triggers");
+ if (grammar_triggers != data.end()) {
+ for (const auto & t : *grammar_triggers) {
+ common_grammar_trigger trigger;
+ trigger.word = t.at("word");
+ trigger.at_start = t.at("at_start");
+
+ auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
+ if (ids.size() == 1) {
+ LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
+ params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+ params.sampling.preserved_tokens.insert(ids[0]);
+ continue;
+ }
+ LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
+ params.sampling.grammar_trigger_words.push_back(trigger);
+ }
+ }
+ const auto preserved_tokens = data.find("preserved_tokens");
+ if (preserved_tokens != data.end()) {
+ for (const auto & t : *preserved_tokens) {
+ auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
+ if (ids.size() == 1) {
+ LOG_DBG("Preserved token: %d\n", ids[0]);
+ params.sampling.preserved_tokens.insert(ids[0]);
+ } else {
+ // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
+ LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
+ }
+ }
+ }
+ if (params.sampling.grammar_lazy) {
+ GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
+ }
  }

  {
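Illustrative sketch, not part of the upstream diff: the parsing above reads "grammar_triggers" entries with "word"/"at_start" fields plus optional "preserved_tokens" and "grammar_lazy" flags. A hypothetical request body exercising them, built with the nlohmann json alias that server.cpp already uses (the trigger words and grammar string are made up):

#include <nlohmann/json.hpp>
using json = nlohmann::ordered_json;

// Hypothetical payload for the lazy-grammar fields parsed above.
static json make_lazy_grammar_request() {
    return json{
        {"prompt",       "Call a tool to answer."},
        {"grammar",      "root ::= ..."},   // placeholder grammar, illustration only
        {"grammar_lazy", true},             // grammar is enforced only after a trigger fires
        {"grammar_triggers", json::array({
            {{"word", "<tool_call>"}, {"at_start", false}}
        })},
        {"preserved_tokens", json::array({"<tool_call>", "</tool_call>"})}
    };
}
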
@@ -377,22 +443,12 @@ struct server_task {
  }

  {
- const auto & samplers = data.find("samplers");
+ const auto samplers = data.find("samplers");
  if (samplers != data.end()) {
  if (samplers->is_array()) {
- std::vector<std::string> sampler_names;
- for (const auto & name : *samplers) {
- if (name.is_string()) {
- sampler_names.emplace_back(name);
- }
- }
- params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
+ params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
  } else if (samplers->is_string()){
- std::string sampler_string;
- for (const auto & name : *samplers) {
- sampler_string += name;
- }
- params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
+ params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
  }
  } else {
  params.sampling.samplers = defaults.sampling.samplers;
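Aside, not part of the upstream diff: the simplified samplers parsing leans on nlohmann::json's built-in conversions, a JSON array of strings to std::vector<std::string> and a JSON string to std::string, which is what makes the hand-written accumulation loops unnecessary. A standalone sketch:

#include <cassert>
#include <nlohmann/json.hpp>
#include <string>
#include <vector>
using json = nlohmann::json;

int main() {
    const json arr = json::parse(R"(["top_k","top_p","min_p","temperature"])");
    const auto names = arr.get<std::vector<std::string>>();   // element-wise conversion

    const json str = json::parse(R"("kpmt")");
    const auto chars = str.get<std::string>();

    assert(names.size() == 4 && chars == "kpmt");
    return 0;
}
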
@@ -539,7 +595,7 @@ struct completion_token_output {
  struct server_task_result_cmpl_final : server_task_result {
  int index = 0;

- std::string content;
+ std::string content;
  llama_tokens tokens;

  bool stream;
@@ -561,10 +617,11 @@ struct server_task_result_cmpl_final : server_task_result {
  slot_params generation_params;

  // OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

  virtual int get_index() override {
  return index;
@@ -658,18 +715,44 @@ struct server_task_result_cmpl_final : server_task_result {

  json to_json_oaicompat_chat() {
  std::string finish_reason = "length";
+ common_chat_msg msg;
  if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
- finish_reason = "stop";
+ LOG_DBG("Parsing chat message: %s\n", content.c_str());
+ msg = common_chat_parse(content, oaicompat_chat_format);
+ finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
+ } else {
+ msg.content = content;
+ }
+
+ json tool_calls;
+ if (!msg.tool_calls.empty()) {
+ tool_calls = json::array();
+ for (const auto & tc : msg.tool_calls) {
+ tool_calls.push_back({
+ {"type", "function"},
+ {"function", {
+ {"name", tc.name},
+ {"arguments", tc.arguments},
+ }},
+ {"id", tc.id},
+ });
+ }
  }

- json choice = json{
+ json message {
+ {"content", msg.content},
+ {"tool_calls", tool_calls},
+ {"role", "assistant"},
+ };
+ if (!msg.tool_plan.empty()) {
+ message["tool_plan"] = msg.tool_plan;
+ }
+
+ json choice {
  {"finish_reason", finish_reason},
  {"index", 0},
- {"message", json {
- {"content", content},
- {"role", "assistant"}
- }
- }};
+ {"message", message},
+ };

  if (!stream && probs_output.size() > 0) {
  choice["logprobs"] = json{
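Sketch, not part of the upstream diff and using simplified stand-in structs rather than the real common_chat_msg: the mapping above yields an OpenAI-style assistant message in which each parsed tool call becomes a "function" entry, with finish_reason switching to "tool_calls" when any are present.

#include <nlohmann/json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;

struct tool_call_t { std::string name, arguments, id; };   // stand-in for a parsed tool call
struct chat_msg_t  { std::string content; std::vector<tool_call_t> tool_calls; };

static json to_oai_message(const chat_msg_t & msg) {
    json tool_calls = json::array();
    for (const auto & tc : msg.tool_calls) {
        tool_calls.push_back({
            {"type", "function"},
            {"function", {{"name", tc.name}, {"arguments", tc.arguments}}},
            {"id", tc.id}
        });
    }
    return json{
        {"role", "assistant"},
        {"content", msg.content},
        {"tool_calls", tool_calls}
    };
}
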
@@ -711,7 +794,7 @@ struct server_task_result_cmpl_final : server_task_result {
  finish_reason = "stop";
  }

- json choice = json{
+ json choice = json {
  {"finish_reason", finish_reason},
  {"index", 0},
  {"delta", json::object()}
@@ -1186,6 +1269,8 @@ struct server_slot {

  llama_token sampled;

+ common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+
  // stats
  size_t n_sent_text = 0; // number of sent text character

@@ -1422,6 +1507,10 @@ struct server_queue {
  int post(server_task task, bool front = false) {
  std::unique_lock<std::mutex> lock(mutex_tasks);
  GGML_ASSERT(task.id != -1);
+ // if this is cancel task make sure to clean up pending tasks
+ if (task.type == SERVER_TASK_TYPE_CANCEL) {
+ cleanup_pending_task(task.id_target);
+ }
  QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
  if (front) {
  queue_tasks.push_front(std::move(task));
@@ -1439,6 +1528,10 @@ struct server_queue {
  if (task.id == -1) {
  task.id = id++;
  }
+ // if this is cancel task make sure to clean up pending tasks
+ if (task.type == SERVER_TASK_TYPE_CANCEL) {
+ cleanup_pending_task(task.id_target);
+ }
  QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
  if (front) {
  queue_tasks.push_front(std::move(task));
@@ -1539,6 +1632,20 @@ struct server_queue {
  }
  }
  }
+
+ private:
+ void cleanup_pending_task(int id_target) {
+ // no need lock because this is called exclusively by post()
+ auto rm_func = [id_target](const server_task & task) {
+ return task.id_target == id_target;
+ };
+ queue_tasks.erase(
+ std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
+ queue_tasks.end());
+ queue_tasks_deferred.erase(
+ std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
+ queue_tasks_deferred.end());
+ }
  };

  struct server_response {
1574
1681
 
1575
1682
  std::unique_lock<std::mutex> lock(mutex_results);
1576
1683
  waiting_task_ids.erase(id_task);
1684
+ // make sure to clean up all pending results
1685
+ queue_results.erase(
1686
+ std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
1687
+ return res->id == id_task;
1688
+ }),
1689
+ queue_results.end());
1577
1690
  }
1578
1691
 
1579
1692
  void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
@@ -1593,7 +1706,7 @@ struct server_response {
1593
1706
  return !queue_results.empty();
1594
1707
  });
1595
1708
 
1596
- for (int i = 0; i < (int) queue_results.size(); i++) {
1709
+ for (size_t i = 0; i < queue_results.size(); i++) {
1597
1710
  if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
1598
1711
  server_task_result_ptr res = std::move(queue_results[i]);
1599
1712
  queue_results.erase(queue_results.begin() + i);
@@ -1610,12 +1723,6 @@ struct server_response {
1610
1723
  server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
1611
1724
  while (true) {
1612
1725
  std::unique_lock<std::mutex> lock(mutex_results);
1613
- bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
1614
- return !queue_results.empty();
1615
- });
1616
- if (!cr_res) {
1617
- return nullptr;
1618
- }
1619
1726
 
1620
1727
  for (int i = 0; i < (int) queue_results.size(); i++) {
1621
1728
  if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
@@ -1624,6 +1731,11 @@ struct server_response {
  return res;
  }
  }
+
+ std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
+ if (cr_res == std::cv_status::timeout) {
+ return nullptr;
+ }
  }

  // should never reach here
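Standalone sketch, not part of the upstream diff, of the pattern the two hunks above switch to: check the result queue before blocking, then use a plain timed wait, so a result pushed before recv_with_timeout() was entered can no longer be missed and a timeout maps directly onto std::cv_status::timeout.

#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <optional>

static std::mutex              g_mutex;
static std::condition_variable g_cv;
static std::deque<int>         g_results;   // stand-in for queue_results

std::optional<int> recv_with_timeout(int timeout_s) {
    while (true) {
        std::unique_lock<std::mutex> lock(g_mutex);
        if (!g_results.empty()) {            // check first ...
            const int r = g_results.front();
            g_results.pop_front();
            return r;
        }
        // ... then block; a timeout means "nothing arrived yet"
        if (g_cv.wait_for(lock, std::chrono::seconds(timeout_s)) == std::cv_status::timeout) {
            return std::nullopt;
        }
    }
}
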
@@ -1688,6 +1800,8 @@ struct server_context {
  // Necessary similarity of prompt for slot selection
  float slot_prompt_similarity = 0.0f;

+ common_chat_templates chat_templates;
+
  ~server_context() {
  // Clear any sampling context
  for (server_slot & slot : slots) {
@@ -1728,13 +1842,16 @@ struct server_context {
  add_bos_token = llama_vocab_get_add_bos(vocab);
  has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;

- if (!params_base.speculative.model.empty()) {
+ if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
  SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());

  auto params_dft = params_base;

  params_dft.devices = params_base.speculative.devices;
+ params_dft.hf_file = params_base.speculative.hf_file;
+ params_dft.hf_repo = params_base.speculative.hf_repo;
  params_dft.model = params_base.speculative.model;
+ params_dft.model_url = params_base.speculative.model_url;
  params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
  params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
  params_dft.n_parallel = 1;
@@ -1762,16 +1879,48 @@ struct server_context {
  // force F16 KV cache for the draft model for extra performance
  cparams_dft.type_k = GGML_TYPE_F16;
  cparams_dft.type_v = GGML_TYPE_F16;
+
+ // the context is not needed - we will create one for each slot
+ llama_init_dft.context.reset();
  }

+ if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
+ LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
+ chat_templates = common_chat_templates_from_model(model, "chatml");
+ } else {
+ chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
+ }
+ GGML_ASSERT(chat_templates.template_default.get() != nullptr);
+
  return true;
  }

- bool validate_builtin_chat_template() const {
+ bool validate_builtin_chat_template(bool use_jinja) const {
  llama_chat_message chat[] = {{"user", "test"}};
- const char * tmpl = llama_model_chat_template(model);
- const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
- return chat_res > 0;
+
+ if (use_jinja) {
+ auto templates = common_chat_templates_from_model(model, "");
+ common_chat_inputs inputs;
+ inputs.messages = json::array({{
+ {"role", "user"},
+ {"content", "test"},
+ }});
+ GGML_ASSERT(templates.template_default);
+ try {
+ common_chat_params_init(*templates.template_default, inputs);
+ if (templates.template_tool_use) {
+ common_chat_params_init(*templates.template_tool_use, inputs);
+ }
+ return true;
+ } catch (const std::exception & e) {
+ SRV_ERR("failed to apply template: %s\n", e.what());
+ return false;
+ }
+ } else {
+ const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
+ const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
+ return chat_res > 0;
+ }
  }

  void init() {
@@ -2210,11 +2359,11 @@ struct server_context {
  res->id_slot = slot.id;

  res->index = slot.index;
- res->content = slot.generated_text;
- res->tokens = slot.generated_tokens;
+ res->content = std::move(slot.generated_text);
+ res->tokens = std::move(slot.generated_tokens);
  res->timings = slot.get_timings();
  res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
- res->response_fields = slot.params.response_fields;
+ res->response_fields = std::move(slot.params.response_fields);

  res->truncated = slot.truncated;
  res->n_decoded = slot.n_decoded;
@@ -2225,12 +2374,12 @@ struct server_context {
  res->stop = slot.stop;
  res->post_sampling_probs = slot.params.post_sampling_probs;

- res->verbose = slot.params.verbose;
- res->stream = slot.params.stream;
- res->oaicompat = slot.params.oaicompat;
- res->oaicompat_model = slot.params.oaicompat_model;
- res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-
+ res->verbose = slot.params.verbose;
+ res->stream = slot.params.stream;
+ res->oaicompat = slot.params.oaicompat;
+ res->oaicompat_model = slot.params.oaicompat_model;
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+ res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
  // populate res.probs_output
  if (slot.params.sampling.n_probs > 0) {
  if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
@@ -2338,8 +2487,8 @@ struct server_context {

  server_task task(SERVER_TASK_TYPE_CANCEL);
  task.id_target = id_task;
- cancel_tasks.push_back(task);
  queue_results.remove_waiting_task_id(id_task);
+ cancel_tasks.push_back(task);
  }
  // push to beginning of the queue, so it has highest priority
  queue_tasks.post(cancel_tasks, true);
@@ -2708,6 +2857,10 @@ struct server_context {
  // track if given slot can be batched with slots already in the batch
  server_slot * slot_batched = nullptr;

+ auto accept_special_token = [&](server_slot & slot, llama_token token) {
+ return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
+ };
+
  // frist, add sampled tokens from any ongoing sequences
  for (auto & slot : slots) {
  if (slot.state != SLOT_STATE_GENERATING) {
@@ -3071,7 +3224,7 @@ struct server_context {

  completion_token_output result;
  result.tok = id;
- result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+ result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
  result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs

  if (slot.params.sampling.n_probs > 0) {
@@ -3160,7 +3313,7 @@ struct server_context {
  completion_token_output result;

  result.tok = ids[i];
- result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+ result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
  result.prob = 1.0f; // set later

  // TODO: set result.probs
@@ -3200,6 +3353,8 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
  return;
  }

+ // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
+
  LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);

  LOG_DBG("request: %s\n", req.body.c_str());
@@ -3286,9 +3441,13 @@ int main(int argc, char ** argv) {
  message = "Unknown Exception";
  }

- json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
- LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
- res_error(res, formatted_error);
+ try {
+ json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
+ LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
+ res_error(res, formatted_error);
+ } catch (const std::exception & e) {
+ LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
+ }
  });

  svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
@@ -3510,11 +3669,11 @@ int main(int argc, char ** argv) {
  {"value", (uint64_t) res_metrics->kv_cache_tokens_count}
  },{
  {"name", "requests_processing"},
- {"help", "Number of request processing."},
+ {"help", "Number of requests processing."},
  {"value", (uint64_t) res_metrics->n_processing_slots}
  },{
  {"name", "requests_deferred"},
- {"help", "Number of request deferred."},
+ {"help", "Number of requests deferred."},
  {"value", (uint64_t) res_metrics->n_tasks_deferred}
  }}}
  };
@@ -3656,9 +3815,14 @@ int main(int argc, char ** argv) {
  { "default_generation_settings", ctx_server.default_generation_settings_for_props },
  { "total_slots", ctx_server.params_base.n_parallel },
  { "model_path", ctx_server.params_base.model },
- { "chat_template", common_get_builtin_chat_template(ctx_server.model) },
+ { "chat_template", ctx_server.chat_templates.template_default->source() },
+ { "bos_token", ctx_server.chat_templates.template_default->bos_token() },
+ { "eos_token", ctx_server.chat_templates.template_default->eos_token() },
  { "build_info", build_info },
  };
+ if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
+ data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
+ }

  res_ok(res, data);
  };
@@ -3695,7 +3859,9 @@ int main(int argc, char ** argv) {
  std::vector<server_task> tasks;

  try {
- std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
+ const auto & prompt = data.at("prompt");
+ LOG_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
  tasks.reserve(tokenized_prompts.size());
  for (size_t i = 0; i < tokenized_prompts.size(); i++) {
  server_task task = server_task(type);
@@ -3711,8 +3877,8 @@ int main(int argc, char ** argv) {
  task.id_selected_slot = json_value(data, "id_slot", -1);

  // OAI-compat
- task.params.oaicompat = oaicompat;
- task.params.oaicompat_cmpl_id = completion_id;
+ task.params.oaicompat = oaicompat;
+ task.params.oaicompat_cmpl_id = completion_id;
  // oaicompat_model is already populated by params_from_json_cmpl

  tasks.push_back(task);
@@ -3881,12 +4047,15 @@ int main(int argc, char ** argv) {
  };

  const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+ LOG_DBG("request: %s\n", req.body.c_str());
  if (ctx_server.params_base.embedding) {
  res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
  return;
  }

- json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+ auto body = json::parse(req.body);
+ json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
+
  return handle_completions_impl(
  SERVER_TASK_TYPE_COMPLETION,
  data,
@@ -3895,6 +4064,13 @@ int main(int argc, char ** argv) {
  OAICOMPAT_TYPE_CHAT);
  };

+ // same with handle_chat_completions, but without inference part
+ const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
+ auto body = json::parse(req.body);
+ json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates);
+ res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
+ };
+
  const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
  json models = {
  {"object", "list"},
@@ -4229,6 +4405,7 @@ int main(int argc, char ** argv) {
  svr->Post("/v1/reranking", handle_rerank);
  svr->Post("/tokenize", handle_tokenize);
  svr->Post("/detokenize", handle_detokenize);
+ svr->Post("/apply-template", handle_apply_template);
  // LoRA adapters hotswap
  svr->Get ("/lora-adapters", handle_lora_adapters_list);
  svr->Post("/lora-adapters", handle_lora_adapters_apply);
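Hypothetical client for the /apply-template route registered above: it renders the chat template and returns the prompt without running inference. The sketch assumes a server already listening on localhost:8080 and uses cpp-httplib, which the server itself vendors; host, port, and the message content are made up.

#include <iostream>
#include <nlohmann/json.hpp>
#include "httplib.h"
using json = nlohmann::json;

int main() {
    httplib::Client cli("localhost", 8080);
    const json body = {
        {"messages", json::array({
            {{"role", "user"}, {"content", "Hello!"}}
        })}
    };
    const auto res = cli.Post("/apply-template", body.dump(), "application/json");
    if (res && res->status == 200) {
        std::cout << json::parse(res->body).at("prompt") << "\n";  // rendered prompt
        return 0;
    }
    return 1;
}
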
@@ -4294,24 +4471,18 @@ int main(int argc, char ** argv) {

  LOG_INF("%s: model loaded\n", __func__);

- // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
- if (params.chat_template.empty()) {
- if (!ctx_server.validate_builtin_chat_template()) {
- LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
- params.chat_template = "chatml";
- }
- }
-
  // print sample chat example to make it clear which template is used
  LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
- params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
- common_chat_format_example(ctx_server.model, params.chat_template).c_str());
+ ctx_server.chat_templates.template_default->source().c_str(),
+ common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());

- ctx_server.queue_tasks.on_new_task(std::bind(
- &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+ ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
+ ctx_server.process_single_task(task);
+ });

- ctx_server.queue_tasks.on_update_slots(std::bind(
- &server_context::update_slots, &ctx_server));
+ ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
+ ctx_server.update_slots();
+ });

  shutdown_handler = [&](int) {
  ctx_server.queue_tasks.terminate();
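Aside, not part of the upstream diff: the last hunk swaps std::bind for capturing lambdas when registering the queue callbacks; both produce an equivalent callable, but the lambda states the call explicitly and avoids the placeholder machinery. A rough, self-contained illustration with a made-up worker type:

#include <functional>

struct worker_t { void process(int task_id) { (void) task_id; } };

int main() {
    worker_t worker;
    // before (sketch): std::function<void(int)> cb = std::bind(&worker_t::process, &worker, std::placeholders::_1);
    std::function<void(int)> cb = [&worker](int task_id) { worker.process(task_id); };
    cb(42);
    return 0;
}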