@fugood/llama.node 1.0.1 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,7 +1,7 @@
  {
  "name": "@fugood/llama.node",
  "access": "public",
- "version": "1.0.1",
+ "version": "1.0.2",
  "description": "An another Node binding of llama.cpp",
  "main": "lib/index.js",
  "scripts": {
@@ -70,19 +70,19 @@
  "CMakeLists.txt"
  ],
  "optionalDependencies": {
- "@fugood/node-llama-linux-x64": "1.0.1",
- "@fugood/node-llama-linux-x64-vulkan": "1.0.1",
- "@fugood/node-llama-linux-x64-cuda": "1.0.1",
- "@fugood/node-llama-linux-arm64": "1.0.1",
- "@fugood/node-llama-linux-arm64-vulkan": "1.0.1",
- "@fugood/node-llama-linux-arm64-cuda": "1.0.1",
- "@fugood/node-llama-win32-x64": "1.0.1",
- "@fugood/node-llama-win32-x64-vulkan": "1.0.1",
- "@fugood/node-llama-win32-x64-cuda": "1.0.1",
- "@fugood/node-llama-win32-arm64": "1.0.1",
- "@fugood/node-llama-win32-arm64-vulkan": "1.0.1",
- "@fugood/node-llama-darwin-x64": "1.0.1",
- "@fugood/node-llama-darwin-arm64": "1.0.1"
+ "@fugood/node-llama-linux-x64": "1.0.2",
+ "@fugood/node-llama-linux-x64-vulkan": "1.0.2",
+ "@fugood/node-llama-linux-x64-cuda": "1.0.2",
+ "@fugood/node-llama-linux-arm64": "1.0.2",
+ "@fugood/node-llama-linux-arm64-vulkan": "1.0.2",
+ "@fugood/node-llama-linux-arm64-cuda": "1.0.2",
+ "@fugood/node-llama-win32-x64": "1.0.2",
+ "@fugood/node-llama-win32-x64-vulkan": "1.0.2",
+ "@fugood/node-llama-win32-x64-cuda": "1.0.2",
+ "@fugood/node-llama-win32-arm64": "1.0.2",
+ "@fugood/node-llama-win32-arm64-vulkan": "1.0.2",
+ "@fugood/node-llama-darwin-x64": "1.0.2",
+ "@fugood/node-llama-darwin-arm64": "1.0.2"
  },
  "devDependencies": {
  "@babel/preset-env": "^7.24.4",
@@ -1,5 +1,5 @@
  diff --git a/src/llama.cpp/common/chat.cpp b/src/llama.cpp/common/chat.cpp
- index 7d9aaeb1..a7b68d4a 100644
+ index 114dbfcc..6771bd43 100644
  --- a/src/llama.cpp/common/chat.cpp
  +++ b/src/llama.cpp/common/chat.cpp
  @@ -6,9 +6,6 @@
@@ -12,7 +12,7 @@ index 7d9aaeb1..a7b68d4a 100644
  #include <cstdio>
  #include <exception>
  #include <iostream>
- @@ -121,14 +118,6 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
+ @@ -123,14 +120,6 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
  return diffs;
  }

@@ -27,13 +27,13 @@ index 7d9aaeb1..a7b68d4a 100644
  struct templates_params {
  json messages;
  json tools;
- diff --git a/src/llama.cpp/common/chat.h b/src/llama.cpp/common/chat.h
- index 9f59e6b0..9b7fe724 100644
+ diff --git a/common/chat.h b/common/chat.h
+ index ca807c14..56649863 100644
  --- a/src/llama.cpp/common/chat.h
  +++ b/src/llama.cpp/common/chat.h
- @@ -8,7 +8,16 @@
- #include <string>
+ @@ -9,7 +9,16 @@
  #include <vector>
+ #include <map>

  -struct common_chat_templates;
  +#include <minja/chat-template.hpp>
@@ -62,10 +62,10 @@ index e4e71ad1..091ddda4 100644
  mparams.split_mode = params.split_mode;
  mparams.tensor_split = params.tensor_split;
  diff --git a/src/llama.cpp/common/common.h b/src/llama.cpp/common/common.h
- index e08a59ea..d120b67d 100644
+ index 8922090e..3c2d1a6a 100644
  --- a/src/llama.cpp/common/common.h
  +++ b/src/llama.cpp/common/common.h
- @@ -223,6 +223,7 @@ enum common_reasoning_format {
+ @@ -224,6 +224,7 @@ enum common_reasoning_format {
  };

  struct common_params {
@@ -74,7 +74,7 @@ index e08a59ea..d120b67d 100644
  int32_t n_ctx = 4096; // context size
  int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
  diff --git a/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt b/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
- index 71b1d67b..093cd6f9 100644
+ index 671fad4d..93fc3cd7 100644
  --- a/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
  +++ b/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
  @@ -104,7 +104,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
@@ -87,10 +87,10 @@ index 71b1d67b..093cd6f9 100644
  check_cxx_compiler_flag(-mfp16-format=ieee GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E)
  if (NOT "${GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
  diff --git a/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt b/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt
- index 39f022f3..7ae9047e 100644
+ index b97e7bf9..c3eb9519 100644
  --- a/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt
  +++ b/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt
- @@ -110,7 +110,7 @@ if (Vulkan_FOUND)
+ @@ -111,7 +111,7 @@ if (Vulkan_FOUND)
  endif()

  # Set up toolchain for host compilation whether cross-compiling or not
@@ -99,7 +99,7 @@ index 39f022f3..7ae9047e 100644
  if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
  set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
  else()
- @@ -130,7 +130,7 @@ if (Vulkan_FOUND)
+ @@ -131,7 +131,7 @@ if (Vulkan_FOUND)

  include(ExternalProject)

@@ -2794,6 +2794,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
  params.ssl_file_cert = value;
  }
  ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE"));
+ add_opt(common_arg(
+ {"--chat-template-kwargs"}, "STRING",
+ string_format("sets additional params for the json template parser"),
+ [](common_params & params, const std::string & value) {
+ auto parsed = json::parse(value);
+ for (const auto & item : parsed.items()) {
+ params.default_template_kwargs[item.key()] = item.value().dump();
+ }
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_CHAT_TEMPLATE_KWARGS"));
  add_opt(common_arg(
  {"-to", "--timeout"}, "N",
  string_format("server read/write timeout in seconds (default: %d)", params.timeout_read),
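Note on the option added above: --chat-template-kwargs (env LLAMA_CHAT_TEMPLATE_KWARGS) accepts a JSON object and stores each value back as JSON text in params.default_template_kwargs. A minimal standalone sketch of that parsing step, using nlohmann::json as llama.cpp does; the sample input value is hypothetical:

    #include <nlohmann/json.hpp>
    #include <iostream>
    #include <map>
    #include <string>

    int main() {
        using json = nlohmann::ordered_json;
        std::map<std::string, std::string> default_template_kwargs;

        // Hypothetical value passed as: --chat-template-kwargs '{"enable_thinking": false}'
        const std::string value = R"({"enable_thinking": false})";

        for (const auto & item : json::parse(value).items()) {
            // Values are kept as JSON text and re-parsed later when building extra_context.
            default_template_kwargs[item.key()] = item.value().dump();
        }

        std::cout << default_template_kwargs["enable_thinking"] << "\n";  // prints: false
        return 0;
    }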
@@ -14,6 +14,8 @@
  #include <string>
  #include <vector>

+ using json = nlohmann::ordered_json;
+
  static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
  auto time = std::chrono::system_clock::to_time_t(now);
  auto local_time = *std::localtime(&time);
@@ -129,6 +131,7 @@ struct templates_params {
  bool add_generation_prompt = true;
  bool enable_thinking = true;
  std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+ json extra_context;
  };

  common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -709,16 +712,23 @@ static void foreach_function(const json & tools, const std::function<void(const

  static std::string apply(
  const common_chat_template & tmpl,
- const nlohmann::ordered_json & messages,
- const nlohmann::ordered_json & tools,
- bool add_generation_prompt,
- const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
+ const struct templates_params & inputs,
+ const std::optional<json> & messages_override = std::nullopt,
+ const std::optional<json> & tools_override = std::nullopt,
+ const std::optional<json> & additional_context = std::nullopt)
  {
  minja::chat_template_inputs tmpl_inputs;
- tmpl_inputs.messages = messages;
- tmpl_inputs.tools = tools;
- tmpl_inputs.add_generation_prompt = add_generation_prompt;
- tmpl_inputs.extra_context = extra_context;
+ tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
+ if (tools_override) {
+ tmpl_inputs.tools = *tools_override;
+ } else {
+ tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
+ }
+ tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
+ tmpl_inputs.extra_context = inputs.extra_context;
+ if (additional_context) {
+ tmpl_inputs.extra_context.merge_patch(*additional_context);
+ }
  // TODO: add flag to control date/time, if only for testing purposes.
  // tmpl_inputs.now = std::chrono::system_clock::now();

@@ -817,7 +827,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
  inputs.messages,
  "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");

- data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
  data.format = COMMON_CHAT_FORMAT_GENERIC;
  return data;
  }
@@ -893,7 +903,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
  data.preserved_tokens = {
  "[TOOL_CALLS]",
  };
- data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ data.prompt = apply(tmpl, inputs);
  data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
  return data;
  }
@@ -923,7 +933,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
  adjusted_messages.push_back(msg);
  }
  }
- data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
+ data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
  data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
  if (string_ends_with(data.prompt, "<|START_THINKING|>")) {
  if (!inputs.enable_thinking) {
@@ -1111,7 +1121,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
  } else {
  data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  }
- data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+ data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json {
  {"date_string", format_time(inputs.now, "%d %b %Y")},
  {"tools_in_user_message", false},
  {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
@@ -1176,7 +1186,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w

  static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
  common_chat_params data;
- auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ auto prompt = apply(tmpl, inputs);

  // Hacks to fix the official (broken) prompt.
  // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
@@ -1271,7 +1281,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
  static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
  LOG_DBG("%s\n", __func__);
  common_chat_params data;
- data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+ data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ json(), json {
  {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
  {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
  });
@@ -1327,7 +1337,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
  // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
  // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
  common_chat_params data;
- data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ data.prompt = apply(tmpl, inputs);
  data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
  if (inputs.tools.is_array() && !inputs.tools.empty()) {
  data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1454,7 +1464,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
  data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  }

- data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ data.prompt = apply(tmpl, inputs);
  // TODO: if (has_raw_python)
  return data;
  }
@@ -1487,14 +1497,15 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
  static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
  common_chat_params data;

- json additional_context = {
+ json extra_context = json {
  {"enable_thinking", inputs.enable_thinking},
  };
+ extra_context.update(inputs.extra_context);

- data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context);
+ data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, extra_context);
  data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
  if (string_ends_with(data.prompt, "<think>\n")) {
- if (!inputs.enable_thinking) {
+ if (!extra_context["enable_thinking"]) {
  data.prompt += "</think>";
  } else {
  data.thinking_forced_open = true;
@@ -1680,7 +1691,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {

  static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
  common_chat_params data;
- data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ data.prompt = apply(tmpl, inputs);
  data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
  data.grammar_lazy = false;
  if (!inputs.json_schema.is_null()) {
@@ -1711,6 +1722,12 @@ static common_chat_params common_chat_templates_apply_jinja(
  params.enable_thinking = inputs.enable_thinking;
  params.grammar = inputs.grammar;
  params.now = inputs.now;
+
+ params.extra_context = json::object();
+ for (auto el : inputs.chat_template_kwargs) {
+ params.extra_context[el.first] = json::parse(el.second);
+ }
+
  if (!inputs.json_schema.empty()) {
  params.json_schema = json::parse(inputs.json_schema);
  }
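The loop above turns each chat_template_kwargs entry into a JSON value on params.extra_context; the reworked apply() then copies that object into the minja template inputs and merge-patches any format-specific additional_context over it. A small standalone sketch of just that merge behavior, assuming only nlohmann::json (minja and templates_params are omitted):

    #include <nlohmann/json.hpp>
    #include <iostream>

    int main() {
        using json = nlohmann::ordered_json;

        // Request-level context, e.g. derived from --chat-template-kwargs.
        json extra_context = {{"enable_thinking", false}};

        // Format-specific additions, as a caller of apply() would pass them.
        json additional_context = {{"date_string", "01 Jan 2025"}};

        // Same merge the new apply() performs before rendering the template.
        extra_context.merge_patch(additional_context);

        std::cout << extra_context.dump(2) << "\n";
        return 0;
    }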
@@ -7,6 +7,7 @@
  #include <chrono>
  #include <string>
  #include <vector>
+ #include <map>

  #include <minja/chat-template.hpp>
  #include <minja/minja.hpp>
@@ -134,6 +135,7 @@ struct common_chat_templates_inputs {
  common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
  bool enable_thinking = true;
  std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+ std::map<std::string, std::string> chat_template_kwargs;
  };

  struct common_chat_params {
@@ -8,6 +8,7 @@
  #include <string>
  #include <string_view>
  #include <vector>
+ #include <map>
  #include <sstream>

  #ifdef _WIN32
@@ -382,6 +383,8 @@ struct common_params {
  std::string ssl_file_key = ""; // NOLINT
  std::string ssl_file_cert = ""; // NOLINT

+ std::map<std::string, std::string> default_template_kwargs;
+
  // "advanced" endpoints are disabled by default for better security
  bool webui = true;
  bool endpoint_slots = false;
@@ -339,7 +339,7 @@ extern "C" {
  typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);

  // Compare the output of two backends
- GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+ GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);

  // Tensor initialization
  GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
@@ -134,6 +134,7 @@ extern "C" {

  GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);

+ GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
  GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
  GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
  GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
@@ -470,6 +470,7 @@ extern "C" {
  GGML_OP_TRANSPOSE,
  GGML_OP_GET_ROWS,
  GGML_OP_GET_ROWS_BACK,
+ GGML_OP_SET_ROWS,
  GGML_OP_DIAG,
  GGML_OP_DIAG_MASK_INF,
  GGML_OP_DIAG_MASK_ZERO,
@@ -519,6 +520,8 @@ extern "C" {
  GGML_OP_CROSS_ENTROPY_LOSS_BACK,
  GGML_OP_OPT_STEP_ADAMW,

+ GGML_OP_GLU,
+
  GGML_OP_COUNT,
  };

@@ -542,6 +545,14 @@ extern "C" {
  GGML_UNARY_OP_COUNT,
  };

+ enum ggml_glu_op {
+ GGML_GLU_OP_REGLU,
+ GGML_GLU_OP_GEGLU,
+ GGML_GLU_OP_SWIGLU,
+
+ GGML_GLU_OP_COUNT,
+ };
+
  enum ggml_object_type {
  GGML_OBJECT_TYPE_TENSOR,
  GGML_OBJECT_TYPE_GRAPH,
@@ -657,6 +668,7 @@ extern "C" {
  GGML_API const char * ggml_op_symbol(enum ggml_op op);

  GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+ GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
  GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name

  GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
@@ -687,6 +699,9 @@ extern "C" {
  // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
  GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);

+ // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+ GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
+
  GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
  GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

@@ -758,6 +773,7 @@ extern "C" {
  GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);

  GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+ GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);

  GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
  GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
@@ -1086,6 +1102,63 @@ extern "C" {
  struct ggml_context * ctx,
  struct ggml_tensor * a);

+ // gated linear unit ops
+ // A: n columns, r rows,
+ // result is n / 2 columns, r rows,
+ // expects gate in second half of row, unless swapped is true
+ GGML_API struct ggml_tensor * ggml_glu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_glu_op op,
+ bool swapped);
+
+ GGML_API struct ggml_tensor * ggml_reglu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_reglu_swapped(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_geglu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_geglu_swapped(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_swiglu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_swiglu_swapped(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // A: n columns, r rows,
+ // B: n columns, r rows,
+ GGML_API struct ggml_tensor * ggml_glu_split(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_glu_op op);
+
+ GGML_API struct ggml_tensor * ggml_reglu_split(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_geglu_split(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_swiglu_split(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
  // normalize along rows
  GGML_API struct ggml_tensor * ggml_norm(
  struct ggml_context * ctx,
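Usage note for the gated-linear-unit ops declared above: they consume a tensor whose rows hold the value and gate halves side by side and return a tensor with half as many columns. A hedged sketch of building and running a SwiGLU node on the CPU backend (headers and sizes are illustrative; the input data is left uninitialized for brevity):

    #include "ggml.h"
    #include "ggml-cpu.h"

    int main() {
        struct ggml_init_params params = {
            /* .mem_size   = */ 16 * 1024 * 1024,
            /* .mem_buffer = */ NULL,
            /* .no_alloc   = */ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // 8 columns x 4 rows; the second half of each row is treated as the gate.
        struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
        struct ggml_tensor * out = ggml_swiglu(ctx, a);   // result: 4 columns x 4 rows

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /* n_threads= */ 1);

        ggml_free(ctx);
        return 0;
    }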
@@ -1375,6 +1448,23 @@ extern "C" {
  struct ggml_tensor * b, // row indices
  struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape

+ // a TD [n_embd, ne1, ne2, ne3]
+ // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
+ // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
+ //
+ // undefined behavior if destination rows overlap
+ //
+ // broadcast:
+ // ne2 % ne11 == 0
+ // ne3 % ne12 == 0
+ //
+ // return view(a)
+ GGML_API struct ggml_tensor * ggml_set_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, // destination
+ struct ggml_tensor * b, // source
+ struct ggml_tensor * c); // row indices
+
  GGML_API struct ggml_tensor * ggml_diag(
  struct ggml_context * ctx,
  struct ggml_tensor * a);
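ggml_set_rows is the write-side counterpart of ggml_get_rows: rows of b are scattered into a at the I64 indices in c, and the op returns a view of a. A hedged sketch of constructing the node (shapes follow the comment above; graph building and computation work as in the previous sketch):

    #include "ggml.h"

    int main() {
        struct ggml_init_params params = {
            /* .mem_size   = */ 16 * 1024 * 1024,
            /* .mem_buffer = */ NULL,
            /* .no_alloc   = */ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // Destination: 16 rows of 64 floats; source: 4 rows; indices: 4 int64 row ids.
        struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 16);
        struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
        struct ggml_tensor * idx = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 4);

        // Rows of src land in dst at the row indices stored in idx; the result is a view of dst.
        struct ggml_tensor * view = ggml_set_rows(ctx, dst, src, idx);
        (void) view;  // expand into a cgraph and compute to actually perform the writes

        ggml_free(ctx);
        return 0;
    }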
@@ -195,6 +195,7 @@ typedef pthread_t ggml_thread_t;

  static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
  [GGML_TYPE_F32] = {
+ .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
  .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
  .vec_dot_type = GGML_TYPE_F32,
  .nrows = 1,
@@ -1817,6 +1818,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  {
  ggml_compute_forward_get_rows_back(params, tensor);
  } break;
+ case GGML_OP_SET_ROWS:
+ {
+ ggml_compute_forward_set_rows(params, tensor);
+ } break;
  case GGML_OP_DIAG:
  {
  ggml_compute_forward_diag(params, tensor);
@@ -1944,6 +1949,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  {
  ggml_compute_forward_unary(params, tensor);
  } break;
+ case GGML_OP_GLU:
+ {
+ ggml_compute_forward_glu(params, tensor);
+ } break;
  case GGML_OP_GET_REL_POS:
  {
  ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2154,6 +2163,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
  GGML_ABORT("fatal error");
  }
  break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(node)) {
+ case GGML_GLU_OP_REGLU:
+ case GGML_GLU_OP_GEGLU:
+ case GGML_GLU_OP_SWIGLU:
+ {
+ n_tasks = n_threads;
+ } break;
+ default:
+ GGML_ABORT("fatal error");
+ }
+ break;
  case GGML_OP_SILU_BACK:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
@@ -2170,6 +2191,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
  n_tasks = n_threads;
  } break;
  case GGML_OP_GET_ROWS:
+ case GGML_OP_SET_ROWS:
  {
  // FIXME: get_rows can use additional threads, but the cost of launching additional threads
  // decreases performance with GPU offloading
@@ -3124,6 +3146,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
  return ggml_graph_compute(cgraph, &cplan);
  }

+ void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
+ memcpy(y, x, n * sizeof(float));
+ }
+
  void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
  int64_t i = 0;
  #if defined(__F16C__)
@@ -416,6 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st

  switch (op->op) {
  case GGML_OP_CPY:
+ case GGML_OP_SET_ROWS:
  return
  op->type != GGML_TYPE_IQ3_XXS &&
  op->type != GGML_TYPE_IQ3_S &&