@fugood/llama.node 1.1.6 → 1.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. package/lib/binding.ts +4 -0
  2. package/lib/index.js +6 -1
  3. package/lib/index.ts +6 -0
  4. package/lib/version.js +5 -0
  5. package/lib/version.ts +2 -0
  6. package/package.json +14 -14
  7. package/scripts/llama.cpp.patch +9 -9
  8. package/src/LlamaCompletionWorker.cpp +73 -20
  9. package/src/LlamaCompletionWorker.h +8 -0
  10. package/src/llama.cpp/CMakeLists.txt +2 -0
  11. package/src/llama.cpp/common/arg.cpp +124 -40
  12. package/src/llama.cpp/common/chat-parser.cpp +9 -1
  13. package/src/llama.cpp/common/chat.cpp +312 -9
  14. package/src/llama.cpp/common/chat.h +4 -1
  15. package/src/llama.cpp/common/common.cpp +54 -0
  16. package/src/llama.cpp/common/common.h +41 -7
  17. package/src/llama.cpp/ggml/CMakeLists.txt +2 -0
  18. package/src/llama.cpp/ggml/include/ggml-opt.h +25 -6
  19. package/src/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  20. package/src/llama.cpp/ggml/include/ggml.h +28 -2
  21. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  22. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +1 -1
  23. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +1136 -1077
  24. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +14 -0
  25. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +6 -0
  26. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  27. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  28. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +63 -2
  29. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -1
  30. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +200 -51
  31. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +11 -0
  32. package/src/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  33. package/src/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  34. package/src/llama.cpp/include/llama.h +25 -0
  35. package/src/llama.cpp/src/llama-batch.cpp +1 -1
  36. package/src/llama.cpp/src/llama-chat.cpp +2 -4
  37. package/src/llama.cpp/src/llama-context.cpp +29 -17
  38. package/src/llama.cpp/src/llama-context.h +6 -5
  39. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +12 -6
  40. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +2 -2
  41. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +89 -69
  42. package/src/llama.cpp/src/llama-kv-cache-unified.h +2 -2
  43. package/src/llama.cpp/src/llama-memory-hybrid.cpp +6 -2
  44. package/src/llama.cpp/src/llama-memory-hybrid.h +2 -2
  45. package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -2
  46. package/src/llama.cpp/src/llama-memory-recurrent.h +2 -2
  47. package/src/llama.cpp/src/llama-memory.h +2 -2
  48. package/src/llama.cpp/src/llama-model.cpp +1 -0
  49. package/src/llama.cpp/src/llama-model.h +1 -0
  50. package/src/llama.cpp/src/llama-quant.cpp +1 -1
  51. package/src/llama.cpp/src/llama-vocab.cpp +2 -1
package/src/llama.cpp/common/chat.cpp

@@ -283,6 +283,7 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
         }
         if (!msg.reasoning_content.empty()) {
             jmsg["reasoning_content"] = msg.reasoning_content;
+            jmsg["thinking"] = msg.reasoning_content; // gpt-oss
         }
         if (!msg.tool_name.empty()) {
             jmsg["name"] = msg.tool_name;
@@ -459,11 +460,12 @@ std::string common_chat_format_single(
     return ss.str();
 }
 
-std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
+std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja, const std::map<std::string, std::string> & chat_template_kwargs) {
     common_chat_templates_inputs inputs;
     inputs.use_jinja = use_jinja;
     inputs.add_bos = tmpls->add_bos;
     inputs.add_eos = tmpls->add_eos;
+    inputs.chat_template_kwargs = chat_template_kwargs;
     auto add_simple_msg = [&](auto role, auto content) {
         common_chat_msg msg;
         msg.role = role;
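A sketch of the updated call, assuming tmpls is a common_chat_templates_ptr returned by common_chat_templates_init() and that the kwargs values are JSON-encoded strings forwarded into the template's extra context (both assumptions, not verified by this diff):

    std::map<std::string, std::string> kwargs = {{"reasoning_effort", "\"high\""}}; // hypothetical key
    std::string example = common_chat_format_example(tmpls.get(), /* use_jinja */ true, kwargs);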
@@ -539,6 +541,17 @@ common_chat_templates_ptr common_chat_templates_init(
             default_template_src = CHATML_TEMPLATE_SRC;
         }
     }
+
+    // TODO @ngxson : this is a temporary hack to prevent chat template from throwing an error
+    // Ref: https://github.com/ggml-org/llama.cpp/pull/15230#issuecomment-3173959633
+    if (default_template_src.find("<|channel|>") != std::string::npos
+            // search for the error message and patch it
+            && default_template_src.find("in message.content or") != std::string::npos) {
+        string_replace_all(default_template_src,
+            "{%- if \"<|channel|>analysis<|message|>\" in message.content or \"<|channel|>final<|message|>\" in message.content %}",
+            "{%- if false %}");
+    }
+
     std::string token_bos = bos_token_override;
     std::string token_eos = eos_token_override;
     bool add_bos = false;
@@ -593,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
+        case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
         case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
         default:
             throw std::runtime_error("Unknown chat format");
@@ -605,11 +619,25 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
         case COMMON_REASONING_FORMAT_AUTO: return "auto";
         case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
         case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
+        case COMMON_REASONING_FORMAT_GRANITE: return "granite";
         default:
             throw std::runtime_error("Unknown reasoning format");
     }
 }
 
+common_reasoning_format common_reasoning_format_from_name(const std::string & format) {
+    if (format == "none") {
+        return COMMON_REASONING_FORMAT_NONE;
+    } else if (format == "auto") {
+        return COMMON_REASONING_FORMAT_AUTO;
+    } else if (format == "deepseek") {
+        return COMMON_REASONING_FORMAT_DEEPSEEK;
+    } else if (format == "deepseek-legacy") {
+        return COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY;
+    }
+    throw std::runtime_error("Unknown reasoning format: " + format);
+}
+
 static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
     std::string arguments;
     if (builder.is_partial()) {
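Note that the new common_reasoning_format_from_name only accepts the four pre-existing names; "granite" can be printed by common_reasoning_format_name but is not parsed back and would throw. A minimal round-trip sketch:

    common_reasoning_format fmt = common_reasoning_format_from_name("deepseek");
    // fmt == COMMON_REASONING_FORMAT_DEEPSEEK
    const char * name = common_reasoning_format_name(fmt); // "deepseek"
    // common_reasoning_format_from_name("granite") throws "Unknown reasoning format: granite"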
@@ -1299,16 +1327,164 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
     data.prompt = prompt;
     data.format = COMMON_CHAT_FORMAT_GPT_OSS;
 
-    // TODO: support tool calls in GPT-OSS?
+    // These special tokens are required to parse properly, so we include them
+    // even if parse_tool_calls is false.
+    data.preserved_tokens = {
+        "<|channel|>",
+        "<|constrain|>",
+        "<|message|>",
+        "<|start|>",
+        "<|end|>",
+    };
+
+    if (inputs.tools.is_array() && !inputs.tools.empty()) {
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            // tool calls can appear in commentary or analysis channels
+            auto channel = builder.add_rule("channel", "\"<|channel|>\" ( \"commentary\" | \"analysis\" )");
+
+            std::vector<std::string> tool_rules_recipient_in_role;
+            std::vector<std::string> tool_rules_recipient_in_channel;
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
+
+                tool_rules_recipient_in_role.push_back(
+                    builder.add_rule(name + "-call",
+                        "\"" + name + "\"" + channel + " \" <|constrain|>json\"? \"<|message|>\" " +
+                        builder.add_schema(name + "-args", parameters)
+                    )
+                );
+
+                tool_rules_recipient_in_channel.push_back(
+                    builder.add_rule(name + "-call",
+                        "\"" + name + "\"" + " \" <|constrain|>json\"? \"<|message|>\" " +
+                        builder.add_schema(name + "-args", parameters)
+                    )
+                );
+            });
+
+            auto recipient_in_role = builder.add_rule("recipient_in_role",
+                "\"<|start|>assistant\"? \" to=functions.\" ( " +
+                string_join(tool_rules_recipient_in_role, " | ") + " )"
+            );
+
+            auto recipient_in_channel = builder.add_rule("recipient_in_channel",
+                channel + " \" to=functions.\" ( " +
+                string_join(tool_rules_recipient_in_channel, " | ") + " )"
+            );
+
+            builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel);
+
+            // Trigger on tool calls that appear in the commentary channel
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
+                "<\\|channel\\|>(commentary|analysis) to"
+            });
+
+            // Trigger tool calls that appear in the role section, either at the
+            // start or in the middle.
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+                "^ to"
+            });
+
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
+                "<\\|start\\|>assistant to"
+            });
+        });
+    }
 
     return data;
 }
 static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
-    // TODO @ngxson : this won't work with --special enabled, we should fix that
-    builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
-    if (!builder.syntax().parse_tool_calls) {
-        builder.add_content(builder.consume_rest());
-        return;
+    static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
+    static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
+
+    static const common_regex start_regex("<\\|start\\|>assistant");
+    static const common_regex analysis_regex("<\\|channel\\|>analysis");
+    static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
+    static const common_regex preamble_regex("<\\|channel\\|>commentary");
+    static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
+    static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
+
+    auto consume_end = [&](bool include_end = false) {
+        if (auto res = builder.try_find_literal("<|end|>")) {
+            return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
+        }
+        return builder.consume_rest();
+    };
+
+    auto handle_tool_call = [&](const std::string & name) {
+        if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
+            if (builder.syntax().parse_tool_calls) {
+                if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
+                    throw common_chat_msg_partial_exception("incomplete tool call");
+                }
+            } else if (args->is_partial) {
+                throw common_chat_msg_partial_exception("incomplete tool call");
+            }
+        }
+    };
+
+    auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
+        auto match = regex.search(input, 0, true);
+        if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
+            return match;
+        }
+        return std::nullopt;
+    };
+
+    do {
+        auto header_start_pos = builder.pos();
+        auto content_start = builder.try_find_literal("<|message|>");
+        if (!content_start) {
+            throw common_chat_msg_partial_exception("incomplete header");
+        }
+
+        auto header = content_start->prelude;
+
+        if (auto match = regex_match(tool_call1_regex, header)) {
+            auto group = match->groups[1];
+            auto name = header.substr(group.begin, group.end - group.begin);
+            handle_tool_call(name);
+            continue;
+        }
+
+        if (auto match = regex_match(tool_call2_regex, header)) {
+            auto group = match->groups[2];
+            auto name = header.substr(group.begin, group.end - group.begin);
+            handle_tool_call(name);
+            continue;
+        }
+
+        if (regex_match(analysis_regex, header)) {
+            builder.move_to(header_start_pos);
+            if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
+                builder.add_content(consume_end(true));
+            } else {
+                builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
+            }
+            continue;
+        }
+
+        if (regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
+            builder.add_content(consume_end());
+            continue;
+        }
+
+        // Possibly a malformed message, attempt to recover by rolling
+        // back to pick up the next <|start|>
+        LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str());
+        builder.move_to(header_start_pos);
+    } while (builder.try_find_regex(start_regex, std::string::npos, false));
+
+    auto remaining = builder.consume_rest();
+    if (!remaining.empty()) {
+        LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str());
     }
 }
 
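To make the new parser concrete, here is a hedged sketch of how a Harmony-style reply is split into reasoning, content, and a tool call. The common_chat_syntax field names are inferred from the calls visible above (reasoning_format, reasoning_in_content, parse_tool_calls); the reply text itself is illustrative:

    common_chat_syntax syntax;
    syntax.format           = COMMON_CHAT_FORMAT_GPT_OSS;
    syntax.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
    syntax.parse_tool_calls = true;

    std::string raw =
        "<|channel|>analysis<|message|>Need the forecast first.<|end|>"
        "<|start|>assistant to=functions.get_weather<|channel|>commentary <|constrain|>json"
        "<|message|>{\"location\": \"Paris\"}";

    common_chat_msg msg = common_chat_parse(raw, /* is_partial */ false, syntax);
    // msg.reasoning_content: "Need the forecast first."
    // msg.tool_calls[0]:     name "get_weather", arguments {"location": "Paris"}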
@@ -1721,6 +1897,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
     builder.add_content(builder.consume_rest());
 }
 
+static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+
+    // Pass thinking context for Granite template
+    json additional_context = {
+        {"thinking", inputs.enable_thinking},
+    };
+
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
+    data.format = COMMON_CHAT_FORMAT_GRANITE;
+
+    if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
+        if (!inputs.enable_thinking) {
+            data.prompt += "</think>";
+        } else {
+            data.thinking_forced_open = true;
+        }
+    }
+
+    if (!inputs.tools.is_null()) {
+        // Granite uses <|tool_call|> followed by JSON list
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
+                tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name +
+                    "-args", {
+                    {"type", "object"},
+                    {"properties", {
+                        {"name", {{"const", name}}},
+                        {"arguments", parameters},
+                    }},
+                    {"required", json::array({"name", "arguments"})},
+                })));
+            });
+
+            auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
+            auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");
+
+            if (data.thinking_forced_open) {
+                builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
+            } else {
+                builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
+            }
+
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+                "<|tool_call|>"
+            });
+
+            data.preserved_tokens = {
+                "<think>",
+                "</think>",
+                "<response>",
+                "</response>",
+                "<|tool_call|>",
+            };
+        });
+    } else {
+        // Handle thinking tags for non-tool responses
+        if (data.thinking_forced_open && inputs.enable_thinking) {
+            data.grammar_lazy = false;
+            data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+                builder.add_rule("root", "\"</think>\" space \"<response>\" space .* \"</response>\" space");
+            });
+            data.preserved_tokens = {
+                "<think>",
+                "</think>",
+                "<response>",
+                "</response>",
+            };
+        }
+    }
+
+    return data;
+}
+
+static void common_chat_parse_granite(common_chat_msg_parser & builder) {
+    // Parse thinking tags
+    builder.try_parse_reasoning("<think>", "</think>");
+
+    // Parse response tags using regex
+    static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
+    if (auto res = builder.try_find_regex(response_regex)) {
+        // Extract the content between the tags (capture group 1)
+        auto content = builder.str(res->groups[1]);
+        builder.add_content(content);
+        builder.move_to(res->groups[0].end);
+    }
+
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    // Look for tool calls
+    static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
+    if (auto res = builder.try_find_regex(tool_call_regex)) {
+        builder.move_to(res->groups[0].end);
+
+        // Expect JSON array of tool calls
+        auto tool_calls_data = builder.consume_json();
+        if (tool_calls_data.json.is_array()) {
+            if (!builder.add_tool_calls(tool_calls_data.json)) {
+                builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
+            }
+        } else {
+            builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
+        }
+    } else {
+        builder.add_content(builder.consume_rest());
+    }
+}
+
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
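As with GPT-OSS above, a hedged sketch of what common_chat_parse_granite extracts from a Granite-style reply (field names of common_chat_syntax are inferred from the calls above; the reply text is illustrative):

    common_chat_syntax syntax;
    syntax.format           = COMMON_CHAT_FORMAT_GRANITE;
    syntax.reasoning_format = COMMON_REASONING_FORMAT_GRANITE;
    syntax.parse_tool_calls = true;

    std::string raw =
        "<think>Need to call the weather tool.</think>"
        "<response>Let me check that for you.</response>"
        "<|tool_call|>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Berlin\"}}]";

    common_chat_msg msg = common_chat_parse(raw, /* is_partial */ false, syntax);
    // msg.reasoning_content: "Need to call the weather tool."
    // msg.content:           "Let me check that for you."
    // msg.tool_calls[0]:     name "get_weather", arguments {"city": "Berlin"}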
@@ -1754,8 +2048,8 @@ static common_chat_params common_chat_templates_apply_jinja(
     params.enable_thinking = inputs.enable_thinking;
     params.grammar = inputs.grammar;
     params.now = inputs.now;
-    params.add_bos = inputs.add_bos;
-    params.add_eos = inputs.add_eos;
+    params.add_bos = tmpls->add_bos;
+    params.add_eos = tmpls->add_eos;
 
     params.extra_context = json::object();
     for (auto el : inputs.chat_template_kwargs) {
@@ -1792,6 +2086,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_command_r7b(tmpl, params);
     }
 
+    // Granite (IBM) - detects thinking / tools support
+    if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
+        return common_chat_params_init_granite(tmpl, params);
+    }
+
     // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
     if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
         return common_chat_params_init_hermes_2_pro(tmpl, params);
@@ -1852,6 +2151,7 @@ static common_chat_params common_chat_templates_apply_legacy(
     int alloc_size = 0;
     std::vector<llama_chat_message> chat;
     std::vector<std::string> contents;
+
     for (const auto & msg : inputs.messages) {
         auto content = msg.content;
         for (const auto & part : msg.content_parts) {
@@ -1953,6 +2253,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_COMMAND_R7B:
             common_chat_parse_command_r7b(builder);
             break;
+        case COMMON_CHAT_FORMAT_GRANITE:
+            common_chat_parse_granite(builder);
+            break;
         case COMMON_CHAT_FORMAT_GPT_OSS:
             common_chat_parse_gpt_oss(builder);
             break;
package/src/llama.cpp/common/chat.h

@@ -120,6 +120,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_GRANITE,
     COMMON_CHAT_FORMAT_GPT_OSS,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
@@ -197,10 +198,12 @@ std::string common_chat_format_single(
 // Returns an example of formatted chat
 std::string common_chat_format_example(
     const struct common_chat_templates * tmpls,
-    bool use_jinja);
+    bool use_jinja,
+    const std::map<std::string, std::string> & chat_template_kwargs);
 
 const char* common_chat_format_name(common_chat_format format);
 const char* common_reasoning_format_name(common_reasoning_format format);
+common_reasoning_format common_reasoning_format_from_name(const std::string & format);
 common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
package/src/llama.cpp/common/common.cpp

@@ -41,6 +41,7 @@
 #endif
 #include <locale>
 #include <windows.h>
+#include <string.h>
 #include <fcntl.h>
 #include <io.h>
 #else
@@ -1566,3 +1567,56 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
 
     return result;
 }
+
+ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
+    const lr_opt & d = *(lr_opt *) userdata;
+    result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
+    result.sgd.wd = result.adamw.wd = d.wd;
+    return result;
+}
+
+// TODO make all command line args case-insensitive
+static inline bool eq_case_insensitive(char const* a, char const* b) {
+    return !
+#if defined(_MSC_VER)
+        _stricmp
+#else
+        strcasecmp
+#endif // defined(_MSC_VER)
+        (a, b);
+}
+
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
+    if (eq_case_insensitive("adamw", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    }
+    if (eq_case_insensitive("sgd", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_SGD;
+    }
+    return GGML_OPT_OPTIMIZER_TYPE_COUNT;
+}
+
+// TODO simplify to use just log and exp
+static float const k_log_2 = std::log(2.f);
+
+void lr_opt::init() {
+    if (lr_min > 0 && lr_min < lr0) {
+        float nhalf = std::log(lr0 / lr_min) / k_log_2;
+        float e = epochs;
+        if (decay_epochs > 0 && decay_epochs < e) {
+            e = decay_epochs;
+        } else {
+            decay_epochs = e;
+        }
+        scale_epoch = nhalf / e;
+    }
+}
+
+float lr_opt::get_lr(float epoch) const {
+    float r = lr_min <= 0 ? lr0 :
+        epoch >= decay_epochs ? lr_min :
+        lr0 * std::pow(0.5f, epoch * scale_epoch);
+    LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
+    return r;
+}
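The schedule implemented by lr_opt::init() and get_lr() is a half-life decay: init() computes how many halvings separate lr0 from lr_min and spreads them over decay_epochs. A worked sketch (numbers chosen for illustration, not from the diff):

    lr_opt lr;
    lr.lr0          = 1e-4f;
    lr.lr_min       = 1e-5f;
    lr.decay_epochs = 10;
    lr.epochs       = 10;
    lr.init();               // nhalf = log(1e-4 / 1e-5) / log(2) ~= 3.32, scale_epoch ~= 0.332
    // lr.get_lr(0)  == 1e-4
    // lr.get_lr(5)  ~= 1e-4 * 0.5^(5 * 0.332) ~= 3.2e-5
    // lr.get_lr(10) == 1e-5  (clamped to lr_min once epoch >= decay_epochs)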
package/src/llama.cpp/common/common.h

@@ -2,14 +2,17 @@
 
 #pragma once
 
-#include "llama-cpp.h"
-
 #include <set>
+#include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
 #include <map>
 #include <sstream>
+#include <cmath>
+
+#include "ggml-opt.h"
+#include "llama-cpp.h"
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -82,6 +85,7 @@ enum llama_example {
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
     LLAMA_EXAMPLE_DIFFUSION,
+    LLAMA_EXAMPLE_FINETUNE,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -202,6 +206,7 @@ struct common_params_speculative {
     float p_split = 0.1f; // speculative decoding split probability
     float p_min = 0.75f;  // minimum speculative decoding probability (greedy)
     std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
+    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
 
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
@@ -239,8 +244,28 @@ enum common_reasoning_format {
     COMMON_REASONING_FORMAT_AUTO,
     COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
     COMMON_REASONING_FORMAT_DEEPSEEK,        // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
+    COMMON_REASONING_FORMAT_GRANITE,         // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
+
+struct lr_opt {
+    float lr0 = 1e-5;        // learning rate at first epoch
+    float lr_min = -1;
+    float decay_epochs = -1; // if >0, the learning rate starts at lr0 and decays to lr_min after this many epochs
+    float scale_epoch = 0;
+    float wd = 0;
+    unsigned epochs = 2;
+
+    unsigned epoch; // set by optimizer outer (epochs) loop
+    // learning rate decay - constant LR per epoch only for now
+    float get_lr(float e) const;
+    float get_lr() const { return get_lr(epoch); }
+    // must call after arg parse, before get_lr
+    void init();
+};
+
+struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
+
 struct common_params {
     bool vocab_only = false;
     int32_t n_predict = -1; // new tokens to predict
@@ -376,6 +401,11 @@ struct common_params {
     bool no_mmproj = false;         // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
+    // finetune
+    struct lr_opt lr;
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    float val_split = 0.05f; // fraction of the data used for the validation set
+
     // embedding
     bool embedding = false;     // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
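A hedged sketch of how these new finetune fields are wired together, based only on the declarations in this diff (common_opt_lr_pars reads the lr_opt through its userdata pointer):

    common_params params;
    params.optimizer       = GGML_OPT_OPTIMIZER_TYPE_SGD; // or GGML_OPT_OPTIMIZER_TYPE_ADAMW (the default)
    params.lr.lr0          = 1e-4f;
    params.lr.lr_min       = 1e-5f;
    params.lr.decay_epochs = 8;
    params.lr.init();    // must run after argument parsing, before any get_lr() call
    params.lr.epoch = 0; // advanced by the outer training loop
    ggml_opt_optimizer_params opt_pars = common_opt_lr_pars(&params.lr);
    // opt_pars.adamw.alpha == opt_pars.sgd.alpha == params.lr.get_lr(0)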
@@ -384,11 +414,12 @@ struct common_params {
     std::string cls_sep = "\t"; // separator of classification sequences
 
     // server params
-    int32_t port           = 8080;         // server listens on this network port
-    int32_t timeout_read   = 600;          // http read timeout in seconds
-    int32_t timeout_write  = timeout_read; // http write timeout in seconds
-    int32_t n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
-    int32_t n_cache_reuse  = 0;            // min chunk size to reuse from the cache via KV shifting
+    int32_t port              = 8080;         // server listens on this network port
+    int32_t timeout_read      = 600;          // http read timeout in seconds
+    int32_t timeout_write     = timeout_read; // http write timeout in seconds
+    int32_t n_threads_http    = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_cache_reuse     = 0;            // min chunk size to reuse from the cache via KV shifting
+    int32_t n_swa_checkpoints = 3;            // max number of SWA checkpoints per slot
 
     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
@@ -703,3 +734,6 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 //
 
 ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
+
+// "adamw" or "sgd" (case insensitive)
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
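Per its definition in common.cpp above, the lookup is case-insensitive and signals an unrecognized name by returning the COUNT sentinel rather than throwing; a minimal sketch:

    enum ggml_opt_optimizer_type opt = common_opt_get_optimizer("AdamW"); // GGML_OPT_OPTIMIZER_TYPE_ADAMW
    if (common_opt_get_optimizer("lion") == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
        // unknown optimizer name; the caller must handle the COUNT sentinel itself
    }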
package/src/llama.cpp/ggml/CMakeLists.txt

@@ -176,6 +176,7 @@ option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM"
 option(GGML_HIP_ROCWMMA_FATTN             "ggml: enable rocWMMA for FlashAttention" OFF)
 option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
 option(GGML_HIP_MMQ_MFMA                  "ggml: enable MFMA MMA for CDNA in MMQ" ON)
+option(GGML_HIP_EXPORT_METRICS            "ggml: enable kernel perf metrics output" OFF)
 option(GGML_MUSA_GRAPHS                   "ggml: use MUSA graph, experimental, unstable" OFF)
 option(GGML_MUSA_MUDNN_COPY               "ggml: enable muDNN for accelerated copy" OFF)
 option(GGML_VULKAN                        "ggml: use Vulkan" OFF)
@@ -187,6 +188,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
 option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
 option(GGML_WEBGPU           "ggml: use WebGPU" OFF)
 option(GGML_WEBGPU_DEBUG     "ggml: enable WebGPU debug output" OFF)
+option(GGML_ZDNN             "ggml: use zDNN" OFF)
 option(GGML_METAL            "ggml: use Metal" ${GGML_METAL_DEFAULT})
 option(GGML_METAL_USE_BF16   "ggml: use bfloat if available" OFF)
 option(GGML_METAL_NDEBUG     "ggml: disable Metal debugging" OFF)