@fugood/llama.node 1.0.3 → 1.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. package/lib/binding.ts +1 -0
  2. package/package.json +14 -14
  3. package/src/LlamaCompletionWorker.cpp +24 -4
  4. package/src/LlamaCompletionWorker.h +7 -1
  5. package/src/LlamaContext.cpp +2 -1
  6. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  7. package/src/llama.cpp/common/arg.cpp +37 -0
  8. package/src/llama.cpp/common/common.cpp +22 -6
  9. package/src/llama.cpp/common/common.h +14 -1
  10. package/src/llama.cpp/ggml/CMakeLists.txt +3 -0
  11. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  12. package/src/llama.cpp/ggml/include/ggml.h +13 -0
  13. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  14. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +23 -8
  16. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  17. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +39 -0
  18. package/src/llama.cpp/include/llama.h +13 -48
  19. package/src/llama.cpp/src/llama-arch.cpp +222 -15
  20. package/src/llama.cpp/src/llama-arch.h +16 -1
  21. package/src/llama.cpp/src/llama-batch.cpp +76 -70
  22. package/src/llama.cpp/src/llama-batch.h +24 -18
  23. package/src/llama.cpp/src/llama-chat.cpp +44 -1
  24. package/src/llama.cpp/src/llama-chat.h +2 -0
  25. package/src/llama.cpp/src/llama-context.cpp +134 -95
  26. package/src/llama.cpp/src/llama-context.h +13 -16
  27. package/src/llama.cpp/src/llama-cparams.h +3 -2
  28. package/src/llama.cpp/src/llama-graph.cpp +239 -154
  29. package/src/llama.cpp/src/llama-graph.h +162 -126
  30. package/src/llama.cpp/src/llama-hparams.cpp +45 -0
  31. package/src/llama.cpp/src/llama-hparams.h +11 -1
  32. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  34. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  35. package/src/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  36. package/src/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  37. package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -9
  38. package/src/llama.cpp/src/llama-model.cpp +2309 -665
  39. package/src/llama.cpp/src/llama-model.h +18 -4
  40. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  41. package/src/llama.cpp/src/llama-vocab.cpp +368 -9
  42. package/src/llama.cpp/src/llama-vocab.h +43 -0
  43. package/src/llama.cpp/src/unicode.cpp +207 -0
  44. package/src/llama.cpp/src/unicode.h +2 -0
package/lib/binding.ts CHANGED
@@ -131,6 +131,7 @@ export type LlamaCompletionResult = {
   tokens_evaluated: number
   truncated: boolean
   context_full: boolean
+  audio_tokens?: Array<number>
   timings: {
     prompt_n: number
     prompt_ms: number
package/package.json CHANGED
@@ -1,7 +1,7 @@
 {
   "name": "@fugood/llama.node",
   "access": "public",
-  "version": "1.0.3",
+  "version": "1.0.5",
   "description": "An another Node binding of llama.cpp",
   "main": "lib/index.js",
   "scripts": {
@@ -70,19 +70,19 @@
     "CMakeLists.txt"
   ],
   "optionalDependencies": {
-    "@fugood/node-llama-linux-x64": "1.0.3",
-    "@fugood/node-llama-linux-x64-vulkan": "1.0.3",
-    "@fugood/node-llama-linux-x64-cuda": "1.0.3",
-    "@fugood/node-llama-linux-arm64": "1.0.3",
-    "@fugood/node-llama-linux-arm64-vulkan": "1.0.3",
-    "@fugood/node-llama-linux-arm64-cuda": "1.0.3",
-    "@fugood/node-llama-win32-x64": "1.0.3",
-    "@fugood/node-llama-win32-x64-vulkan": "1.0.3",
-    "@fugood/node-llama-win32-x64-cuda": "1.0.3",
-    "@fugood/node-llama-win32-arm64": "1.0.3",
-    "@fugood/node-llama-win32-arm64-vulkan": "1.0.3",
-    "@fugood/node-llama-darwin-x64": "1.0.3",
-    "@fugood/node-llama-darwin-arm64": "1.0.3"
+    "@fugood/node-llama-linux-x64": "1.0.5",
+    "@fugood/node-llama-linux-x64-vulkan": "1.0.5",
+    "@fugood/node-llama-linux-x64-cuda": "1.0.5",
+    "@fugood/node-llama-linux-arm64": "1.0.5",
+    "@fugood/node-llama-linux-arm64-vulkan": "1.0.5",
+    "@fugood/node-llama-linux-arm64-cuda": "1.0.5",
+    "@fugood/node-llama-win32-x64": "1.0.5",
+    "@fugood/node-llama-win32-x64-vulkan": "1.0.5",
+    "@fugood/node-llama-win32-x64-cuda": "1.0.5",
+    "@fugood/node-llama-win32-arm64": "1.0.5",
+    "@fugood/node-llama-win32-arm64-vulkan": "1.0.5",
+    "@fugood/node-llama-darwin-x64": "1.0.5",
+    "@fugood/node-llama-darwin-arm64": "1.0.5"
   },
   "devDependencies": {
     "@babel/preset-env": "^7.24.4",
package/src/LlamaCompletionWorker.cpp CHANGED
@@ -32,12 +32,15 @@ LlamaCompletionWorker::LlamaCompletionWorker(
     bool thinking_forced_open,
     std::string reasoning_format,
     const std::vector<std::string> &media_paths,
-    const std::vector<llama_token> &guide_tokens)
+    const std::vector<llama_token> &guide_tokens,
+    bool has_vocoder,
+    tts_type tts_type_val)
     : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess),
       _params(params), _stop_words(stop_words), _chat_format(chat_format),
       _thinking_forced_open(thinking_forced_open),
       _reasoning_format(reasoning_format),
-      _media_paths(media_paths), _guide_tokens(guide_tokens) {
+      _media_paths(media_paths), _guide_tokens(guide_tokens),
+      _has_vocoder(has_vocoder), _tts_type(tts_type_val) {
   if (!callback.IsEmpty()) {
     _tsfn = Napi::ThreadSafeFunction::New(info.Env(), callback,
                                           "LlamaCompletionCallback", 0, 1);
@@ -153,8 +156,7 @@ void LlamaCompletionWorker::Execute() {
     // For multimodal input, n_past might already be set
     // Only decode text tokens if we have any input left
     if (n_input > 0) {
-      int ret =
-          llama_decode(ctx, llama_batch_get_one(embd->data() + n_cur, n_input));
+      int ret = llama_decode(ctx, llama_batch_get_one(embd->data() + n_cur, n_input));
       if (ret < 0) {
         SetError("Failed to decode token, code: " + std::to_string(ret));
         break;
@@ -171,6 +173,15 @@ void LlamaCompletionWorker::Execute() {
     }
     _next_token_uses_guide_token = (new_token_id == 198);
     common_sampler_accept(sampling.get(), new_token_id, true);
+
+    // Collect audio tokens for TTS if vocoder is enabled
+    if (_has_vocoder) {
+      if ((_tts_type == OUTETTS_V0_2 || _tts_type == OUTETTS_V0_3) &&
+          (new_token_id >= 151672 && new_token_id <= 155772)) {
+        _result.audio_tokens.push_back(new_token_id);
+      }
+    }
+
     // prepare the next batch
     embd->emplace_back(new_token_id);
     auto token = common_token_to_piece(ctx, new_token_id);
@@ -291,6 +302,15 @@ void LlamaCompletionWorker::OnOK() {
     result.Set("content", Napi::String::New(env, content.c_str()));
   }
 
+  // Add audio_tokens if vocoder is enabled and we have audio tokens
+  if (_has_vocoder && !_result.audio_tokens.empty()) {
+    auto audio_tokens = Napi::Array::New(env, _result.audio_tokens.size());
+    for (size_t i = 0; i < _result.audio_tokens.size(); i++) {
+      audio_tokens.Set(i, Napi::Number::New(env, _result.audio_tokens[i]));
+    }
+    result.Set("audio_tokens", audio_tokens);
+  }
+
   auto ctx = _sess->context();
   const auto timings_token = llama_perf_context(ctx);
 
package/src/LlamaCompletionWorker.h CHANGED
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "common.hpp"
+#include "tts_utils.h"
 #include <atomic>
 #include <functional>
 #include <napi.h>
@@ -23,7 +24,9 @@ public:
       bool thinking_forced_open,
       std::string reasoning_format,
       const std::vector<std::string> &media_paths = {},
-      const std::vector<llama_token> &guide_tokens = {});
+      const std::vector<llama_token> &guide_tokens = {},
+      bool has_vocoder = false,
+      tts_type tts_type_val = UNKNOWN);
 
   ~LlamaCompletionWorker();
 
@@ -52,6 +55,8 @@ private:
   bool _stop = false;
   Napi::ThreadSafeFunction _tsfn;
   bool _next_token_uses_guide_token = true;
+  bool _has_vocoder;
+  tts_type _tts_type;
   struct {
     size_t tokens_evaluated = 0;
     size_t tokens_predicted = 0;
@@ -62,5 +67,6 @@ private:
     bool stopped_words = false;
     std::string stopping_word;
     bool stopped_limited = false;
+    std::vector<llama_token> audio_tokens;
   } _result;
 };
package/src/LlamaContext.cpp CHANGED
@@ -917,7 +917,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
 
   auto *worker =
       new LlamaCompletionWorker(info, _sess, callback, params, stop_words,
-                                chat_format, thinking_forced_open, reasoning_format, media_paths, guide_tokens);
+                                chat_format, thinking_forced_open, reasoning_format, media_paths, guide_tokens,
+                                _has_vocoder, _tts_type);
   worker->Queue();
   _wip = worker;
   worker->OnComplete([this]() { _wip = nullptr; });
package/src/llama.cpp/common/CMakeLists.txt CHANGED
@@ -86,8 +86,7 @@ if (LLAMA_CURL)
     endif()
     target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
    include_directories(${CURL_INCLUDE_DIRS})
-    find_library(CURL_LIBRARY curl REQUIRED)
-    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
+    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
 endif ()
 
 if (LLAMA_LLGUIDANCE)
@@ -112,13 +111,13 @@ if (LLAMA_LLGUIDANCE)
 
     ExternalProject_Add(llguidance_ext
         GIT_REPOSITORY https://github.com/guidance-ai/llguidance
-        # v0.7.20 (+ fix to build on GCC 15):
-        GIT_TAG b5b8b64dba11c4e4ee6b1d1450d3a3ae279891e8
+        # v1.0.1:
+        GIT_TAG d795912fedc7d393de740177ea9ea761e7905774
         PREFIX ${CMAKE_BINARY_DIR}/llguidance
         SOURCE_DIR ${LLGUIDANCE_SRC}
         BUILD_IN_SOURCE TRUE
         CONFIGURE_COMMAND ""
-        BUILD_COMMAND cargo build --release
+        BUILD_COMMAND cargo build --release --package llguidance
         INSTALL_COMMAND ""
         BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h
         UPDATE_COMMAND ""
package/src/llama.cpp/common/arg.cpp CHANGED
@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--kv-unified", "-kvu"},
+        string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
+        [](common_params & params) {
+            params.kv_unified = true;
+        }
+    ).set_env("LLAMA_ARG_KV_SPLIT"));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
@@ -3423,5 +3431,34 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
 
+    // diffusion parameters
+    add_opt(common_arg(
+        { "--diffusion-steps" }, "N",
+        string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
+        [](common_params & params, int value) { params.diffusion.steps = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-eps" }, "F",
+        string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
+        [](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-algorithm" }, "N",
+        string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
+                      params.diffusion.algorithm),
+        [](common_params & params, int value) { params.diffusion.algorithm = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-alg-temp" }, "F",
+        string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
+        [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-visual" },
+        string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
+                      params.diffusion.visual_mode ? "true" : "false"),
+        [](common_params & params) { params.diffusion.visual_mode = true; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+
     return ctx_arg;
 }
package/src/llama.cpp/common/common.cpp CHANGED
@@ -448,6 +448,15 @@ void string_replace_all(std::string & s, const std::string & search, const std::
 bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
     return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
 }
+
+bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
+    bool has_suffix = string_ends_with(str, suffix);
+    if (has_suffix) {
+        str = str.substr(0, str.size() - suffix.size());
+    }
+    return has_suffix;
+}
+
 size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
     if (!str.empty() && !stop.empty()) {
         const char text_last_char = str.back();
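For reference, a small self-contained sketch (not taken from the package) of how the new string_remove_suffix helper behaves; it restates the two helpers from the hunk above so it compiles on its own, and the example strings are invented:

    #include <cstdio>
    #include <string>
    #include <string_view>

    // copies of the helpers shown in the hunk above
    static bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
        return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
    }

    static bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
        bool has_suffix = string_ends_with(str, suffix);
        if (has_suffix) {
            str = str.substr(0, str.size() - suffix.size());
        }
        return has_suffix;
    }

    int main() {
        std::string piece = "final answer</s>";

        bool stripped = string_remove_suffix(piece, "</s>"); // suffix present: stripped in place
        printf("%d '%s'\n", stripped, piece.c_str());        // prints: 1 'final answer'

        stripped = string_remove_suffix(piece, "</s>");      // suffix absent: string untouched
        printf("%d '%s'\n", stripped, piece.c_str());        // prints: 0 'final answer'
        return 0;
    }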
@@ -1005,15 +1014,21 @@ struct common_init_result common_init_from_params(common_params & params) {
         params.sampling.ignore_eos = false;
     }
 
-    if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-            if (llama_vocab_is_eog(vocab, i)) {
-                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-                params.sampling.logit_bias.push_back({i, -INFINITY});
-            }
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
         }
     }
 
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
     if (params.sampling.penalty_last_n == -1) {
         LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
         params.sampling.penalty_last_n = llama_n_ctx(lctx);
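For reference, a hedged sketch (not from the package) of why a -INFINITY entry per end-of-generation token implements ignore_eos; logit_bias_entry and apply_logit_bias below are local stand-ins that only mirror the shape of llama_logit_bias, and the real application happens inside the llama.cpp sampler:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // stand-in mirroring llama_logit_bias: a token id plus an additive bias
    struct logit_bias_entry { int token; float bias; };

    // adding -INFINITY to a token's logit drives its probability to zero,
    // so a biased EOG token can never be sampled
    static void apply_logit_bias(std::vector<float> & logits, const std::vector<logit_bias_entry> & biases) {
        for (const auto & lb : biases) {
            logits[lb.token] += lb.bias;
        }
    }

    int main() {
        std::vector<float> logits = { 0.1f, 2.0f, 0.5f };              // toy vocab of 3 tokens
        std::vector<logit_bias_entry> eog_bias = { { 1, -INFINITY } }; // pretend token 1 is EOG
        apply_logit_bias(logits, eog_bias);
        printf("%f %f %f\n", logits[0], logits[1], logits[2]);         // token 1 is now -inf
        return 0;
    }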
@@ -1158,6 +1173,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf    = params.no_perf;
     cparams.op_offload = !params.no_op_offload;
     cparams.swa_full   = params.swa_full;
+    cparams.kv_unified = params.kv_unified;
 
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
package/src/llama.cpp/common/common.h CHANGED
@@ -81,6 +81,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_DIFFUSION,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -177,7 +178,8 @@ struct common_params_sampling {
     std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
     std::set<llama_token> preserved_tokens;
 
-    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
     // print the parameters into a string
     std::string print() const;
@@ -217,6 +219,14 @@ struct common_params_vocoder {
     bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
 };
 
+struct common_params_diffusion {
+    int32_t steps       = 64;    // number of diffusion steps
+    float   eps         = 1e-3f; // epsilon for timesteps
+    int32_t algorithm   = 0;     // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
+    float   alg_temp    = 0.0f;  // algorithm temperature
+    bool    visual_mode = false; // show progressive diffusion on screen
+};
+
 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
     COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
@@ -269,6 +279,7 @@ struct common_params {
     struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
     struct common_params_vocoder     vocoder;
+    struct common_params_diffusion   diffusion;
 
     struct common_params_model model;
 
@@ -331,6 +342,7 @@ struct common_params {
     bool no_perf    = false; // disable performance metrics
     bool ctx_shift  = true;  // context shift on inifinite text generation
     bool swa_full   = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+    bool kv_unified = false; // enable unified KV cache
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap = true;          // use mmap for faster loads
@@ -523,6 +535,7 @@ static bool string_starts_with(const std::string & str,
 
 // While we wait for C++20's std::string::ends_with...
 bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
+bool string_remove_suffix(std::string & str, const std::string_view & suffix);
 size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
 
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
package/src/llama.cpp/ggml/CMakeLists.txt CHANGED
@@ -181,6 +181,8 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou
 option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
 option(GGML_VULKAN_VALIDATE          "ggml: enable Vulkan validation" OFF)
 option(GGML_VULKAN_RUN_TESTS         "ggml: run Vulkan tests" OFF)
+option(GGML_WEBGPU                   "ggml: use WebGPU" OFF)
+option(GGML_WEBGPU_DEBUG             "ggml: enable WebGPU debug output" OFF)
 option(GGML_METAL                    "ggml: use Metal" ${GGML_METAL_DEFAULT})
 option(GGML_METAL_USE_BF16           "ggml: use bfloat if available" OFF)
 option(GGML_METAL_NDEBUG             "ggml: disable Metal debugging" OFF)
@@ -270,6 +272,7 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-rpc.h
     include/ggml-sycl.h
     include/ggml-vulkan.h
+    include/ggml-webgpu.h
     include/gguf.h)
 
 set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
package/src/llama.cpp/ggml/include/ggml-webgpu.h ADDED
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define GGML_WEBGPU_NAME "WebGPU"
+
+// Needed for examples in ggml
+GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);
+
+#ifdef __cplusplus
+}
+#endif
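For reference, a minimal sketch (not from the package) of how the two declarations above would be used, assuming ggml was built with GGML_WEBGPU=ON; ggml_backend_name and ggml_backend_free come from ggml-backend.h:

    #include <cstdio>
    #include "ggml-backend.h"
    #include "ggml-webgpu.h"

    int main() {
        // initialize the WebGPU backend, like any other ggml backend
        ggml_backend_t backend = ggml_backend_webgpu_init();
        if (backend == NULL) {
            fprintf(stderr, "WebGPU backend unavailable\n");
            return 1;
        }
        printf("initialized backend: %s\n", ggml_backend_name(backend));
        ggml_backend_free(backend);
        return 0;
    }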
package/src/llama.cpp/ggml/include/ggml.h CHANGED
@@ -1297,6 +1297,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 s);
 
+    // x = s * a + b
+    GGML_API struct ggml_tensor * ggml_scale_bias(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s,
+            float                 b);
+
+    GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s,
+            float                 b);
+
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set(
             struct ggml_context * ctx,
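For reference, a minimal sketch (not from the package) that builds the new fused op in a throwaway ggml context; only graph construction is shown, and executing the graph would go through the usual ggml compute / ggml-backend path:

    #include "ggml.h"

    int main() {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);

        // x = 2*a + 1 as a single node, instead of a ggml_scale followed by a ggml_add
        struct ggml_tensor * x = ggml_scale_bias(ctx, a, 2.0f, 1.0f);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, x);

        ggml_free(ctx);
        return 0;
    }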
package/src/llama.cpp/ggml/src/CMakeLists.txt CHANGED
@@ -370,6 +370,7 @@ ggml_add_backend(MUSA)
 ggml_add_backend(RPC)
 ggml_add_backend(SYCL)
 ggml_add_backend(Vulkan)
+ggml_add_backend(WebGPU)
 ggml_add_backend(OpenCL)
 
 foreach (target ggml-base ggml)