@fugood/llama.node 0.5.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. package/CMakeLists.txt +40 -5
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +46 -0
  17. package/lib/index.js +18 -0
  18. package/lib/index.ts +24 -0
  19. package/package.json +4 -1
  20. package/patches/node-api-headers+1.1.0.patch +26 -0
  21. package/src/DecodeAudioTokenWorker.cpp +40 -0
  22. package/src/DecodeAudioTokenWorker.h +22 -0
  23. package/src/EmbeddingWorker.cpp +7 -5
  24. package/src/LlamaCompletionWorker.cpp +64 -50
  25. package/src/LlamaCompletionWorker.h +6 -7
  26. package/src/LlamaContext.cpp +523 -224
  27. package/src/LlamaContext.h +25 -4
  28. package/src/LoadSessionWorker.cpp +4 -2
  29. package/src/SaveSessionWorker.cpp +10 -6
  30. package/src/TokenizeWorker.cpp +10 -5
  31. package/src/addons.cc +8 -11
  32. package/src/common.hpp +92 -93
  33. package/src/tts_utils.cpp +346 -0
  34. package/src/tts_utils.h +62 -0
  35. package/src/win_dynamic_load.c +2102 -0
  36. package/bin/win32/arm64/llama-node.node +0 -0
  37. package/bin/win32/arm64/node.lib +0 -0
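The native changes in this release register several new TTS-related methods on LlamaContext (initVocoder, releaseVocoder, isVocoderEnabled, getFormattedAudioCompletion, getAudioCompletionGuideTokens, decodeAudioTokens; see the InstanceMethod additions and the comments in package/src/LlamaContext.cpp below). A minimal TypeScript sketch of that surface follows; the signatures are taken from the source comments, while the interface name and the Promise return of decodeAudioTokens (it is backed by an async worker) are assumptions, not the package's published typings.

// Sketch of the TTS-related native methods added in 0.6.x; illustrative only.
interface LlamaContextTTS {
  initVocoder(path: string): boolean;
  releaseVocoder(): void;
  isVocoderEnabled(): boolean;
  getFormattedAudioCompletion(speaker: string | null, text: string): string;
  getAudioCompletionGuideTokens(text: string): Int32Array;
  // Backed by DecodeAudioTokenWorker, so assumed to resolve asynchronously.
  decodeAudioTokens(tokens: number[] | Int32Array): Promise<Float32Array>;
}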
@@ -1,9 +1,5 @@
- #include "ggml.h"
- #include "gguf.h"
- #include "llama-impl.h"
- #include "json.hpp"
- #include "json-schema-to-grammar.h"
  #include "LlamaContext.h"
+ #include "DecodeAudioTokenWorker.h"
  #include "DetokenizeWorker.h"
  #include "DisposeWorker.h"
  #include "EmbeddingWorker.h"
@@ -11,33 +7,42 @@
  #include "LoadSessionWorker.h"
  #include "SaveSessionWorker.h"
  #include "TokenizeWorker.h"
+ #include "ggml.h"
+ #include "gguf.h"
+ #include "json-schema-to-grammar.h"
+ #include "json.hpp"
+ #include "llama-impl.h"
 
+ #include <atomic>
  #include <mutex>
  #include <queue>
- #include <atomic>
 
  // Helper function for formatted strings (for console logs)
- template<typename ... Args>
- static std::string format_string(const std::string& format, Args ... args) {
- int size_s = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; // +1 for null terminator
- if (size_s <= 0) { return "Error formatting string"; }
- auto size = static_cast<size_t>(size_s);
- std::unique_ptr<char[]> buf(new char[size]);
- std::snprintf(buf.get(), size, format.c_str(), args ...);
- return std::string(buf.get(), buf.get() + size - 1); // -1 to exclude null terminator
+ template <typename... Args>
+ static std::string format_string(const std::string &format, Args... args) {
+ int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) +
+ 1; // +1 for null terminator
+ if (size_s <= 0) {
+ return "Error formatting string";
+ }
+ auto size = static_cast<size_t>(size_s);
+ std::unique_ptr<char[]> buf(new char[size]);
+ std::snprintf(buf.get(), size, format.c_str(), args...);
+ return std::string(buf.get(),
+ buf.get() + size - 1); // -1 to exclude null terminator
  }
 
  using json = nlohmann::ordered_json;
 
  // loadModelInfo(path: string): object
- Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo& info) {
+ Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  struct gguf_init_params params = {
- /*.no_alloc = */ false,
- /*.ctx = */ NULL,
+ /*.no_alloc = */ false,
+ /*.ctx = */ NULL,
  };
  std::string path = info[0].ToString().Utf8Value();
-
+
  // Convert Napi::Array to vector<string>
  std::vector<std::string> skip;
  if (info.Length() > 1 && info[1].IsArray()) {
@@ -47,7 +52,7 @@ Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo& info) {
  }
  }
 
- struct gguf_context * ctx = gguf_init_from_file(path.c_str(), params);
+ struct gguf_context *ctx = gguf_init_from_file(path.c_str(), params);
 
  Napi::Object metadata = Napi::Object::New(env);
  if (std::find(skip.begin(), skip.end(), "version") == skip.end()) {
@@ -57,7 +62,8 @@ Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo& info) {
  metadata.Set("alignment", Napi::Number::New(env, gguf_get_alignment(ctx)));
  }
  if (std::find(skip.begin(), skip.end(), "data_offset") == skip.end()) {
- metadata.Set("data_offset", Napi::Number::New(env, gguf_get_data_offset(ctx)));
+ metadata.Set("data_offset",
+ Napi::Number::New(env, gguf_get_data_offset(ctx)));
  }
 
  // kv
@@ -65,7 +71,7 @@ Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo& info) {
  const int n_kv = gguf_get_n_kv(ctx);
 
  for (int i = 0; i < n_kv; ++i) {
- const char * key = gguf_get_key(ctx, i);
+ const char *key = gguf_get_key(ctx, i);
  if (std::find(skip.begin(), skip.end(), key) != skip.end()) {
  continue;
  }
@@ -138,6 +144,24 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
  static_cast<napi_property_attributes>(napi_enumerable)),
  InstanceMethod<&LlamaContext::GetMultimodalSupport>(
  "getMultimodalSupport",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::InitVocoder>(
+ "initVocoder",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::ReleaseVocoder>(
+ "releaseVocoder",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::IsVocoderEnabled>(
+ "isVocoderEnabled",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::GetFormattedAudioCompletion>(
+ "getFormattedAudioCompletion",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::GetAudioCompletionGuideTokens>(
+ "getAudioCompletionGuideTokens",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::DecodeAudioTokens>(
+ "decodeAudioTokens",
  static_cast<napi_property_attributes>(napi_enumerable))});
  Napi::FunctionReference *constructor = new Napi::FunctionReference();
  *constructor = Napi::Persistent(func);
@@ -148,19 +172,13 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
  }
 
  const std::vector<ggml_type> kv_cache_types = {
- GGML_TYPE_F32,
- GGML_TYPE_F16,
- GGML_TYPE_BF16,
- GGML_TYPE_Q8_0,
- GGML_TYPE_Q4_0,
- GGML_TYPE_Q4_1,
- GGML_TYPE_IQ4_NL,
- GGML_TYPE_Q5_0,
- GGML_TYPE_Q5_1,
+ GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
+ GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+ GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
  };
 
- static ggml_type kv_cache_type_from_str(const std::string & s) {
- for (const auto & type : kv_cache_types) {
+ static ggml_type kv_cache_type_from_str(const std::string &s) {
+ for (const auto &type : kv_cache_types) {
  if (ggml_type_name(type) == s) {
  return type;
  }
@@ -168,12 +186,17 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
  throw std::runtime_error("Unsupported cache type: " + s);
  }
 
- static int32_t pooling_type_from_str(const std::string & s) {
- if (s == "none") return LLAMA_POOLING_TYPE_NONE;
- if (s == "mean") return LLAMA_POOLING_TYPE_MEAN;
- if (s == "cls") return LLAMA_POOLING_TYPE_CLS;
- if (s == "last") return LLAMA_POOLING_TYPE_LAST;
- if (s == "rank") return LLAMA_POOLING_TYPE_RANK;
+ static int32_t pooling_type_from_str(const std::string &s) {
+ if (s == "none")
+ return LLAMA_POOLING_TYPE_NONE;
+ if (s == "mean")
+ return LLAMA_POOLING_TYPE_MEAN;
+ if (s == "cls")
+ return LLAMA_POOLING_TYPE_CLS;
+ if (s == "last")
+ return LLAMA_POOLING_TYPE_LAST;
+ if (s == "rank")
+ return LLAMA_POOLING_TYPE_RANK;
  return LLAMA_POOLING_TYPE_UNSPECIFIED;
  }
 
@@ -200,7 +223,8 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
 
  params.chat_template = get_option<std::string>(options, "chat_template", "");
 
- std::string reasoning_format = get_option<std::string>(options, "reasoning_format", "none");
+ std::string reasoning_format =
+ get_option<std::string>(options, "reasoning_format", "none");
  if (reasoning_format == "deepseek") {
  params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
  } else {
@@ -216,16 +240,17 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  params.n_ubatch = params.n_batch;
  }
  params.embd_normalize = get_option<int32_t>(options, "embd_normalize", 2);
- params.pooling_type = (enum llama_pooling_type) pooling_type_from_str(
- get_option<std::string>(options, "pooling_type", "").c_str()
- );
+ params.pooling_type = (enum llama_pooling_type)pooling_type_from_str(
+ get_option<std::string>(options, "pooling_type", "").c_str());
 
  params.cpuparams.n_threads =
  get_option<int32_t>(options, "n_threads", cpu_get_num_math() / 2);
  params.n_gpu_layers = get_option<int32_t>(options, "n_gpu_layers", -1);
  params.flash_attn = get_option<bool>(options, "flash_attn", false);
- params.cache_type_k = kv_cache_type_from_str(get_option<std::string>(options, "cache_type_k", "f16").c_str());
- params.cache_type_v = kv_cache_type_from_str(get_option<std::string>(options, "cache_type_v", "f16").c_str());
+ params.cache_type_k = kv_cache_type_from_str(
+ get_option<std::string>(options, "cache_type_k", "f16").c_str());
+ params.cache_type_v = kv_cache_type_from_str(
+ get_option<std::string>(options, "cache_type_v", "f16").c_str());
  params.ctx_shift = get_option<bool>(options, "ctx_shift", true);
 
  params.use_mlock = get_option<bool>(options, "use_mlock", false);
@@ -296,8 +321,9 @@ Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) {
  return Napi::String::New(info.Env(), _info);
  }
 
- bool validateModelChatTemplate(const struct llama_model * model, const bool use_jinja, const char * name) {
- const char * tmpl = llama_model_chat_template(model, name);
+ bool validateModelChatTemplate(const struct llama_model *model,
+ const bool use_jinja, const char *name) {
+ const char *tmpl = llama_model_chat_template(model, name);
  if (tmpl == nullptr) {
  return false;
  }
@@ -323,68 +349,68 @@ extern "C" void cleanup_logging();
  void LlamaContext::ToggleNativeLog(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  bool enable = info[0].ToBoolean().Value();
-
+
  if (enable) {
  if (!info[1].IsFunction()) {
- Napi::TypeError::New(env, "Callback function required").ThrowAsJavaScriptException();
+ Napi::TypeError::New(env, "Callback function required")
+ .ThrowAsJavaScriptException();
  return;
  }
-
+
  // First clean up existing thread-safe function if any
  if (g_logging_enabled) {
  g_tsfn.Release();
  g_logging_enabled = false;
  }
-
+
  // Create thread-safe function that can be called from any thread
- g_tsfn = Napi::ThreadSafeFunction::New(
- env,
- info[1].As<Napi::Function>(),
- "LLAMA Logger",
- 0,
- 1,
- [](Napi::Env) {
- // Finalizer callback - nothing needed here
- }
- );
+ g_tsfn = Napi::ThreadSafeFunction::New(env, info[1].As<Napi::Function>(),
+ "LLAMA Logger", 0, 1, [](Napi::Env) {
+ // Finalizer callback - nothing
+ // needed here
+ });
 
  g_logging_enabled = true;
-
+
  // Set up log callback
- llama_log_set([](ggml_log_level level, const char* text, void* user_data) {
- // First call the default logger
- llama_log_callback_default(level, text, user_data);
-
- if (!g_logging_enabled) return;
-
- // Determine log level string
- std::string level_str = "";
- if (level == GGML_LOG_LEVEL_ERROR) {
- level_str = "error";
- } else if (level == GGML_LOG_LEVEL_INFO) {
- level_str = "info";
- } else if (level == GGML_LOG_LEVEL_WARN) {
- level_str = "warn";
- }
-
- // Create a heap-allocated copy of the data
- auto* data = new LogMessage{level_str, text};
-
- // Queue callback to be executed on the JavaScript thread
- auto status = g_tsfn.BlockingCall(data, [](Napi::Env env, Napi::Function jsCallback, LogMessage* data) {
- // This code runs on the JavaScript thread
- jsCallback.Call({
- Napi::String::New(env, data->level),
- Napi::String::New(env, data->text)
- });
- delete data;
- });
-
- // If the call failed (e.g., runtime is shutting down), clean up the data
- if (status != napi_ok) {
- delete data;
- }
- }, nullptr);
+ llama_log_set(
+ [](ggml_log_level level, const char *text, void *user_data) {
+ // First call the default logger
+ llama_log_callback_default(level, text, user_data);
+
+ if (!g_logging_enabled)
+ return;
+
+ // Determine log level string
+ std::string level_str = "";
+ if (level == GGML_LOG_LEVEL_ERROR) {
+ level_str = "error";
+ } else if (level == GGML_LOG_LEVEL_INFO) {
+ level_str = "info";
+ } else if (level == GGML_LOG_LEVEL_WARN) {
+ level_str = "warn";
+ }
+
+ // Create a heap-allocated copy of the data
+ auto *data = new LogMessage{level_str, text};
+
+ // Queue callback to be executed on the JavaScript thread
+ auto status = g_tsfn.BlockingCall(
+ data,
+ [](Napi::Env env, Napi::Function jsCallback, LogMessage *data) {
+ // This code runs on the JavaScript thread
+ jsCallback.Call({Napi::String::New(env, data->level),
+ Napi::String::New(env, data->text)});
+ delete data;
+ });
+
+ // If the call failed (e.g., runtime is shutting down), clean up the
+ // data
+ if (status != napi_ok) {
+ delete data;
+ }
+ },
+ nullptr);
  } else {
  // Disable logging
  if (g_logging_enabled) {
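The reworked log hook above forwards every native log line to a JavaScript callback through a ThreadSafeFunction, passing a level string ("error", "warn", "info" or "") and the message text. A hedged TypeScript sketch of consuming it follows; the (level, text) callback arguments mirror what the C++ side sends, but whether toggleNativeLog is exposed on the context instance or at module level is not visible in this hunk, so the instance form here is an assumption.

// Hypothetical minimal typing for the native log hook (illustrative only).
type NativeLogLevel = 'error' | 'warn' | 'info' | '';

interface LlamaLogging {
  toggleNativeLog(
    enable: boolean,
    callback?: (level: NativeLogLevel, text: string) => void,
  ): void;
}

function enableNativeLogging(target: LlamaLogging): void {
  target.toggleNativeLog(true, (level, text) => {
    // Route llama.cpp log lines into the host application's logger.
    if (level === 'error') console.error(text);
    else if (level === 'warn') console.warn(text);
    else console.log(text);
  });
}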
@@ -418,26 +444,51 @@ Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {
  details.Set("size", llama_model_size(model));
 
  Napi::Object chatTemplates = Napi::Object::New(info.Env());
- chatTemplates.Set("llamaChat", validateModelChatTemplate(model, false, ""));
+ chatTemplates.Set("llamaChat", validateModelChatTemplate(model, false, nullptr));
  Napi::Object minja = Napi::Object::New(info.Env());
- minja.Set("default", validateModelChatTemplate(model, true, ""));
+ minja.Set("default", validateModelChatTemplate(model, true, nullptr));
  Napi::Object defaultCaps = Napi::Object::New(info.Env());
- defaultCaps.Set("tools", _templates.get()->template_default->original_caps().supports_tools);
- defaultCaps.Set("toolCalls", _templates.get()->template_default->original_caps().supports_tool_calls);
- defaultCaps.Set("toolResponses", _templates.get()->template_default->original_caps().supports_tool_responses);
- defaultCaps.Set("systemRole", _templates.get()->template_default->original_caps().supports_system_role);
- defaultCaps.Set("parallelToolCalls", _templates.get()->template_default->original_caps().supports_parallel_tool_calls);
- defaultCaps.Set("toolCallId", _templates.get()->template_default->original_caps().supports_tool_call_id);
+ defaultCaps.Set(
+ "tools",
+ _templates.get()->template_default->original_caps().supports_tools);
+ defaultCaps.Set(
+ "toolCalls",
+ _templates.get()->template_default->original_caps().supports_tool_calls);
+ defaultCaps.Set("toolResponses", _templates.get()
+ ->template_default->original_caps()
+ .supports_tool_responses);
+ defaultCaps.Set(
+ "systemRole",
+ _templates.get()->template_default->original_caps().supports_system_role);
+ defaultCaps.Set("parallelToolCalls", _templates.get()
+ ->template_default->original_caps()
+ .supports_parallel_tool_calls);
+ defaultCaps.Set("toolCallId", _templates.get()
+ ->template_default->original_caps()
+ .supports_tool_call_id);
  minja.Set("defaultCaps", defaultCaps);
  minja.Set("toolUse", validateModelChatTemplate(model, true, "tool_use"));
  if (_templates.get()->template_tool_use) {
  Napi::Object toolUseCaps = Napi::Object::New(info.Env());
- toolUseCaps.Set("tools", _templates.get()->template_tool_use->original_caps().supports_tools);
- toolUseCaps.Set("toolCalls", _templates.get()->template_tool_use->original_caps().supports_tool_calls);
- toolUseCaps.Set("toolResponses", _templates.get()->template_tool_use->original_caps().supports_tool_responses);
- toolUseCaps.Set("systemRole", _templates.get()->template_tool_use->original_caps().supports_system_role);
- toolUseCaps.Set("parallelToolCalls", _templates.get()->template_tool_use->original_caps().supports_parallel_tool_calls);
- toolUseCaps.Set("toolCallId", _templates.get()->template_tool_use->original_caps().supports_tool_call_id);
+ toolUseCaps.Set(
+ "tools",
+ _templates.get()->template_tool_use->original_caps().supports_tools);
+ toolUseCaps.Set("toolCalls", _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_tool_calls);
+ toolUseCaps.Set("toolResponses", _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_tool_responses);
+ toolUseCaps.Set("systemRole", _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_system_role);
+ toolUseCaps.Set("parallelToolCalls",
+ _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_parallel_tool_calls);
+ toolUseCaps.Set("toolCallId", _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_tool_call_id);
  minja.Set("toolUseCaps", toolUseCaps);
  }
  chatTemplates.Set("minja", minja);
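getModelInfo now reports chat-template capabilities under chatTemplates.minja.defaultCaps (and toolUseCaps when a tool_use template exists). The TypeScript shape below is reconstructed from the Set(...) calls in this hunk; treat it as illustrative typing, not the package's published .d.ts.

// Reconstructed shape of the chatTemplates block built above (illustrative).
interface ChatTemplateCaps {
  tools: boolean;
  toolCalls: boolean;
  toolResponses: boolean;
  systemRole: boolean;
  parallelToolCalls: boolean;
  toolCallId: boolean;
}

interface ModelChatTemplates {
  llamaChat: boolean; // built-in (non-jinja) chat template support
  minja: {
    default: boolean;
    defaultCaps: ChatTemplateCaps;
    toolUse: boolean;
    toolUseCaps?: ChatTemplateCaps; // only present when a tool_use template exists
  };
}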
@@ -446,20 +497,17 @@ Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {
  details.Set("metadata", metadata);
 
  // Deprecated: use chatTemplates.llamaChat instead
- details.Set("isChatTemplateSupported", validateModelChatTemplate(_sess->model(), false, ""));
+ details.Set("isChatTemplateSupported",
+ validateModelChatTemplate(_sess->model(), false, nullptr));
  return details;
  }
 
  common_chat_params getFormattedChatWithJinja(
- const std::shared_ptr<LlamaSession> &sess,
- const common_chat_templates_ptr &templates,
- const std::string &messages,
- const std::string &chat_template,
- const std::string &json_schema,
- const std::string &tools,
- const bool &parallel_tool_calls,
- const std::string &tool_choice
- ) {
+ const std::shared_ptr<LlamaSession> &sess,
+ const common_chat_templates_ptr &templates, const std::string &messages,
+ const std::string &chat_template, const std::string &json_schema,
+ const std::string &tools, const bool &parallel_tool_calls,
+ const std::string &tool_choice) {
  common_chat_templates_inputs inputs;
  inputs.messages = common_chat_msgs_parse_oaicompat(json::parse(messages));
  auto useTools = !tools.empty();
@@ -473,23 +521,22 @@ common_chat_params getFormattedChatWithJinja(
  if (!json_schema.empty()) {
  inputs.json_schema = json::parse(json_schema);
  }
- inputs.extract_reasoning = sess->params().reasoning_format != COMMON_REASONING_FORMAT_NONE;
+ inputs.extract_reasoning =
+ sess->params().reasoning_format != COMMON_REASONING_FORMAT_NONE;
 
  // If chat_template is provided, create new one and use it (probably slow)
  if (!chat_template.empty()) {
- auto tmps = common_chat_templates_init(sess->model(), chat_template);
- return common_chat_templates_apply(tmps.get(), inputs);
+ auto tmps = common_chat_templates_init(sess->model(), chat_template);
+ return common_chat_templates_apply(tmps.get(), inputs);
  } else {
- return common_chat_templates_apply(templates.get(), inputs);
+ return common_chat_templates_apply(templates.get(), inputs);
  }
  }
 
- std::string getFormattedChat(
- const struct llama_model * model,
- const common_chat_templates_ptr &templates,
- const std::string &messages,
- const std::string &chat_template
- ) {
+ std::string getFormattedChat(const struct llama_model *model,
+ const common_chat_templates_ptr &templates,
+ const std::string &messages,
+ const std::string &chat_template) {
  common_chat_templates_inputs inputs;
  inputs.messages = common_chat_msgs_parse_oaicompat(json::parse(messages));
  inputs.use_jinja = false;
@@ -506,7 +553,8 @@ std::string getFormattedChat(
  // getFormattedChat(
  // messages: [{ role: string, content: string }],
  // chat_template: string,
- // params: { jinja: boolean, json_schema: string, tools: string, parallel_tool_calls: boolean, tool_choice: string }
+ // params: { jinja: boolean, json_schema: string, tools: string,
+ // parallel_tool_calls: boolean, tool_choice: string }
  // ): object | string
  Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
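The comment above documents getFormattedChat(messages, chat_template, params); with jinja enabled the body below also reads response_format, tools, parallel_tool_calls and tool_choice from params and returns an object with prompt, grammar triggers, preserved tokens and additional stops. A hedged TypeScript typing sketch follows; option keys and result fields are taken from this hunk and the jinja branch below, while the interface itself is an assumption rather than the published .d.ts.

// Illustrative typing for the native getFormattedChat call (assumed shape).
interface FormattedChatResult {
  prompt: string;
  chat_format: number; // "chat_format: int" per the source comment
  grammar_triggers: Array<{ type: number; value: string; token: number }>;
  preserved_tokens: string[];
  additional_stops: string[];
}

interface LlamaContextChat {
  getFormattedChat(
    messages: Array<{ role: string; content: string }>,
    chat_template: string | null,
    params?: {
      jinja?: boolean;
      response_format?: {
        type?: 'text' | 'json_object' | 'json_schema';
        json_schema?: { schema?: object };
        schema?: object;
      };
      tools?: object[];
      parallel_tool_calls?: boolean;
      tool_choice?: string;
    },
  ): string | FormattedChatResult;
}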
@@ -517,32 +565,44 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  auto chat_template = info[1].IsString() ? info[1].ToString().Utf8Value() : "";
 
  auto has_params = info.Length() >= 2;
- auto params = has_params ? info[2].As<Napi::Object>() : Napi::Object::New(env);
+ auto params =
+ has_params ? info[2].As<Napi::Object>() : Napi::Object::New(env);
 
  if (get_option<bool>(params, "jinja", false)) {
  std::string json_schema_str = "";
  if (!is_nil(params.Get("response_format"))) {
  auto response_format = params.Get("response_format").As<Napi::Object>();
- auto response_format_type = get_option<std::string>(response_format, "type", "text");
- if (response_format_type == "json_schema" && response_format.Has("json_schema")) {
- auto json_schema = response_format.Get("json_schema").As<Napi::Object>();
- json_schema_str = json_schema.Has("schema") ?
- json_stringify(json_schema.Get("schema").As<Napi::Object>()) :
- "{}";
+ auto response_format_type =
+ get_option<std::string>(response_format, "type", "text");
+ if (response_format_type == "json_schema" &&
+ response_format.Has("json_schema")) {
+ auto json_schema =
+ response_format.Get("json_schema").As<Napi::Object>();
+ json_schema_str =
+ json_schema.Has("schema")
+ ? json_stringify(json_schema.Get("schema").As<Napi::Object>())
+ : "{}";
  } else if (response_format_type == "json_object") {
- json_schema_str = response_format.Has("schema") ?
- json_stringify(response_format.Get("schema").As<Napi::Object>()) :
- "{}";
+ json_schema_str =
+ response_format.Has("schema")
+ ? json_stringify(
+ response_format.Get("schema").As<Napi::Object>())
+ : "{}";
  }
  }
- auto tools_str = params.Has("tools") ?
- json_stringify(params.Get("tools").As<Napi::Array>()) :
- "";
- auto parallel_tool_calls = get_option<bool>(params, "parallel_tool_calls", false);
+ auto tools_str = params.Has("tools")
+ ? json_stringify(params.Get("tools").As<Napi::Array>())
+ : "";
+ auto parallel_tool_calls =
+ get_option<bool>(params, "parallel_tool_calls", false);
  auto tool_choice = get_option<std::string>(params, "tool_choice", "");
 
- auto chatParams = getFormattedChatWithJinja(_sess, _templates, messages, chat_template, json_schema_str, tools_str, parallel_tool_calls, tool_choice);
-
+ auto chatParams = getFormattedChatWithJinja(
+ _sess, _templates, messages, chat_template, json_schema_str, tools_str,
+ parallel_tool_calls, tool_choice);
+
+ console_log(env, std::string("format: ") + std::to_string(chatParams.format));
+
  Napi::Object result = Napi::Object::New(env);
  result.Set("prompt", chatParams.prompt);
  // chat_format: int
@@ -554,30 +614,33 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  // grammar_triggers: [{ value: string, token: number }]
  Napi::Array grammar_triggers = Napi::Array::New(env);
  for (size_t i = 0; i < chatParams.grammar_triggers.size(); i++) {
- const auto & trigger = chatParams.grammar_triggers[i];
- Napi::Object triggerObj = Napi::Object::New(env);
- triggerObj.Set("type", Napi::Number::New(env, trigger.type));
- triggerObj.Set("value", Napi::String::New(env, trigger.value));
- triggerObj.Set("token", Napi::Number::New(env, trigger.token));
- grammar_triggers.Set(i, triggerObj);
+ const auto &trigger = chatParams.grammar_triggers[i];
+ Napi::Object triggerObj = Napi::Object::New(env);
+ triggerObj.Set("type", Napi::Number::New(env, trigger.type));
+ triggerObj.Set("value", Napi::String::New(env, trigger.value));
+ triggerObj.Set("token", Napi::Number::New(env, trigger.token));
+ grammar_triggers.Set(i, triggerObj);
  }
  result.Set("grammar_triggers", grammar_triggers);
  // preserved_tokens: string[]
  Napi::Array preserved_tokens = Napi::Array::New(env);
  for (size_t i = 0; i < chatParams.preserved_tokens.size(); i++) {
- preserved_tokens.Set(i, Napi::String::New(env, chatParams.preserved_tokens[i].c_str()));
+ preserved_tokens.Set(
+ i, Napi::String::New(env, chatParams.preserved_tokens[i].c_str()));
  }
  result.Set("preserved_tokens", preserved_tokens);
  // additional_stops: string[]
  Napi::Array additional_stops = Napi::Array::New(env);
  for (size_t i = 0; i < chatParams.additional_stops.size(); i++) {
- additional_stops.Set(i, Napi::String::New(env, chatParams.additional_stops[i].c_str()));
+ additional_stops.Set(
+ i, Napi::String::New(env, chatParams.additional_stops[i].c_str()));
  }
  result.Set("additional_stops", additional_stops);
 
  return result;
  } else {
- auto formatted = getFormattedChat(_sess->model(), _templates, messages, chat_template);
+ auto formatted =
+ getFormattedChat(_sess->model(), _templates, messages, chat_template);
  return Napi::String::New(env, formatted);
  }
  }
@@ -625,7 +688,9 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
 
  // Check if multimodal is enabled when media_paths are provided
  if (!media_paths.empty() && !(_has_multimodal && _mtmd_ctx != nullptr)) {
- Napi::Error::New(env, "Multimodal support must be enabled via initMultimodal to use media_paths").ThrowAsJavaScriptException();
+ Napi::Error::New(env, "Multimodal support must be enabled via "
+ "initMultimodal to use media_paths")
+ .ThrowAsJavaScriptException();
  return env.Undefined();
  }
 
@@ -641,16 +706,20 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  std::string json_schema_str = "";
  if (options.Has("response_format")) {
  auto response_format = options.Get("response_format").As<Napi::Object>();
- auto response_format_type = get_option<std::string>(response_format, "type", "text");
- if (response_format_type == "json_schema" && response_format.Has("json_schema")) {
+ auto response_format_type =
+ get_option<std::string>(response_format, "type", "text");
+ if (response_format_type == "json_schema" &&
+ response_format.Has("json_schema")) {
  auto json_schema = response_format.Get("json_schema").As<Napi::Object>();
- json_schema_str = json_schema.Has("schema") ?
- json_stringify(json_schema.Get("schema").As<Napi::Object>()) :
- "{}";
+ json_schema_str =
+ json_schema.Has("schema")
+ ? json_stringify(json_schema.Get("schema").As<Napi::Object>())
+ : "{}";
  } else if (response_format_type == "json_object") {
- json_schema_str = response_format.Has("schema") ?
- json_stringify(response_format.Get("schema").As<Napi::Object>()) :
- "{}";
+ json_schema_str =
+ response_format.Has("schema")
+ ? json_stringify(response_format.Get("schema").As<Napi::Object>())
+ : "{}";
  }
  }
 
@@ -659,7 +728,9 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  auto preserved_tokens = options.Get("preserved_tokens").As<Napi::Array>();
  for (size_t i = 0; i < preserved_tokens.Length(); i++) {
  auto token = preserved_tokens.Get(i).ToString().Utf8Value();
- auto ids = common_tokenize(_sess->context(), token, /* add_special= */ false, /* parse_special= */ true);
+ auto ids =
+ common_tokenize(_sess->context(), token, /* add_special= */ false,
+ /* parse_special= */ true);
  if (ids.size() == 1) {
  params.sampling.preserved_tokens.insert(ids[0]);
  }
@@ -672,15 +743,22 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  for (size_t i = 0; i < grammar_triggers.Length(); i++) {
  auto trigger_obj = grammar_triggers.Get(i).As<Napi::Object>();
 
- auto type = static_cast<common_grammar_trigger_type>(trigger_obj.Get("type").ToNumber().Int32Value());
+ auto type = static_cast<common_grammar_trigger_type>(
+ trigger_obj.Get("type").ToNumber().Int32Value());
  auto word = trigger_obj.Get("value").ToString().Utf8Value();
 
  if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
- auto ids = common_tokenize(_sess->context(), word, /* add_special= */ false, /* parse_special= */ true);
+ auto ids =
+ common_tokenize(_sess->context(), word, /* add_special= */ false,
+ /* parse_special= */ true);
  if (ids.size() == 1) {
  auto token = ids[0];
- if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
- throw std::runtime_error("Grammar trigger word should be marked as preserved token");
+ if (std::find(params.sampling.preserved_tokens.begin(),
+ params.sampling.preserved_tokens.end(),
+ (llama_token)token) ==
+ params.sampling.preserved_tokens.end()) {
+ throw std::runtime_error(
+ "Grammar trigger word should be marked as preserved token");
  }
  common_grammar_trigger trigger;
  trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
@@ -688,14 +766,16 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  trigger.token = token;
  params.sampling.grammar_triggers.push_back(std::move(trigger));
  } else {
- params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
+ params.sampling.grammar_triggers.push_back(
+ {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
  }
  } else {
  common_grammar_trigger trigger;
  trigger.type = type;
  trigger.value = word;
  if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
- auto token = (llama_token) trigger_obj.Get("token").ToNumber().Int32Value();
+ auto token =
+ (llama_token)trigger_obj.Get("token").ToNumber().Int32Value();
  trigger.token = token;
  }
  params.sampling.grammar_triggers.push_back(std::move(trigger));
@@ -705,7 +785,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
 
  // Handle grammar_lazy from options
  if (options.Has("grammar_lazy")) {
- params.sampling.grammar_lazy = options.Get("grammar_lazy").ToBoolean().Value();
+ params.sampling.grammar_lazy =
+ options.Get("grammar_lazy").ToBoolean().Value();
  }
 
  if (options.Has("messages") && options.Get("messages").IsArray()) {
@@ -713,29 +794,27 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  auto chat_template = get_option<std::string>(options, "chat_template", "");
  auto jinja = get_option<bool>(options, "jinja", false);
  if (jinja) {
- auto tools_str = options.Has("tools") ?
- json_stringify(options.Get("tools").As<Napi::Array>()) :
- "";
- auto parallel_tool_calls = get_option<bool>(options, "parallel_tool_calls", false);
- auto tool_choice = get_option<std::string>(options, "tool_choice", "none");
+ auto tools_str =
+ options.Has("tools")
+ ? json_stringify(options.Get("tools").As<Napi::Array>())
+ : "";
+ auto parallel_tool_calls =
+ get_option<bool>(options, "parallel_tool_calls", false);
+ auto tool_choice =
+ get_option<std::string>(options, "tool_choice", "none");
 
  auto chatParams = getFormattedChatWithJinja(
- _sess,
- _templates,
- json_stringify(messages),
- chat_template,
- json_schema_str,
- tools_str,
- parallel_tool_calls,
- tool_choice
- );
-
+ _sess, _templates, json_stringify(messages), chat_template,
+ json_schema_str, tools_str, parallel_tool_calls, tool_choice);
+
  params.prompt = chatParams.prompt;
 
  chat_format = chatParams.format;
 
- for (const auto & token : chatParams.preserved_tokens) {
- auto ids = common_tokenize(_sess->context(), token, /* add_special= */ false, /* parse_special= */ true);
+ for (const auto &token : chatParams.preserved_tokens) {
+ auto ids =
+ common_tokenize(_sess->context(), token, /* add_special= */ false,
+ /* parse_special= */ true);
  if (ids.size() == 1) {
  params.sampling.preserved_tokens.insert(ids[0]);
  }
@@ -745,22 +824,18 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  // grammar param always wins jinja template & json_schema
  params.sampling.grammar = chatParams.grammar;
  params.sampling.grammar_lazy = chatParams.grammar_lazy;
- for (const auto & trigger : chatParams.grammar_triggers) {
+ for (const auto &trigger : chatParams.grammar_triggers) {
  params.sampling.grammar_triggers.push_back(trigger);
  }
  has_grammar_set = true;
  }
-
- for (const auto & stop : chatParams.additional_stops) {
+
+ for (const auto &stop : chatParams.additional_stops) {
  stop_words.push_back(stop);
  }
  } else {
  auto formatted = getFormattedChat(
- _sess->model(),
- _templates,
- json_stringify(messages),
- chat_template
- );
+ _sess->model(), _templates, json_stringify(messages), chat_template);
  params.prompt = formatted;
  }
  } else {
@@ -772,7 +847,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  }
 
  if (!has_grammar_set && !json_schema_str.empty()) {
- params.sampling.grammar = json_schema_to_grammar(json::parse(json_schema_str));
+ params.sampling.grammar =
+ json_schema_to_grammar(json::parse(json_schema_str));
  }
 
  params.n_predict = get_option<int32_t>(options, "n_predict", -1);
@@ -794,16 +870,32 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  params.sampling.penalty_present =
  get_option<float>(options, "penalty_present", 0.00f);
  params.sampling.typ_p = get_option<float>(options, "typical_p", 1.00f);
- params.sampling.xtc_threshold = get_option<float>(options, "xtc_threshold", 0.00f);
- params.sampling.xtc_probability = get_option<float>(options, "xtc_probability", 0.10f);
- params.sampling.dry_multiplier = get_option<float>(options, "dry_multiplier", 1.75f);
+ params.sampling.xtc_threshold =
+ get_option<float>(options, "xtc_threshold", 0.00f);
+ params.sampling.xtc_probability =
+ get_option<float>(options, "xtc_probability", 0.10f);
+ params.sampling.dry_multiplier =
+ get_option<float>(options, "dry_multiplier", 1.75f);
  params.sampling.dry_base = get_option<float>(options, "dry_base", 2);
- params.sampling.dry_allowed_length = get_option<float>(options, "dry_allowed_length", -1);
- params.sampling.dry_penalty_last_n = get_option<float>(options, "dry_penalty_last_n", 0);
- params.sampling.top_n_sigma = get_option<float>(options, "top_n_sigma", -1.0f);
+ params.sampling.dry_allowed_length =
+ get_option<float>(options, "dry_allowed_length", -1);
+ params.sampling.dry_penalty_last_n =
+ get_option<float>(options, "dry_penalty_last_n", 0);
+ params.sampling.top_n_sigma =
+ get_option<float>(options, "top_n_sigma", -1.0f);
  params.sampling.ignore_eos = get_option<bool>(options, "ignore_eos", false);
  params.n_keep = get_option<int32_t>(options, "n_keep", 0);
- params.sampling.seed = get_option<int32_t>(options, "seed", LLAMA_DEFAULT_SEED);
+ params.sampling.seed =
+ get_option<int32_t>(options, "seed", LLAMA_DEFAULT_SEED);
+
+ // guide_tokens
+ std::vector<llama_token> guide_tokens;
+ if (options.Has("guide_tokens")) {
+ auto guide_tokens_array = options.Get("guide_tokens").As<Napi::Array>();
+ for (size_t i = 0; i < guide_tokens_array.Length(); i++) {
+ guide_tokens.push_back(guide_tokens_array.Get(i).ToNumber().Int32Value());
+ }
+ }
 
  Napi::Function callback;
  if (info.Length() >= 2) {
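Completion now accepts a guide_tokens array in its options; the values are collected into a std::vector<llama_token> and handed to LlamaCompletionWorker together with the media paths (next hunk). A hedged TypeScript sketch of threading guide tokens into a completion call follows; only the guide_tokens option key and the Int32Array returned by getAudioCompletionGuideTokens come from this diff, while the context typing, the prompt option and the return type are assumptions.

// Illustrative: feed guide tokens into a completion call (assumed wrapper).
interface GuidedCompletionContext {
  getAudioCompletionGuideTokens(text: string): Int32Array;
  completion(options: {
    prompt: string;
    guide_tokens?: number[];
  }): Promise<unknown>;
}

async function completeWithGuideTokens(
  ctx: GuidedCompletionContext,
  prompt: string,
  text: string,
): Promise<unknown> {
  const guide = ctx.getAudioCompletionGuideTokens(text);
  // The native side iterates the option as an array, so pass a plain number[].
  return ctx.completion({ prompt, guide_tokens: Array.from(guide) });
}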
@@ -811,7 +903,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  }
 
  auto *worker =
- new LlamaCompletionWorker(info, _sess, callback, params, stop_words, chat_format, media_paths);
+ new LlamaCompletionWorker(info, _sess, callback, params, stop_words,
+ chat_format, media_paths, guide_tokens);
  worker->Queue();
  _wip = worker;
  worker->OnComplete([this]() { _wip = nullptr; });
@@ -965,7 +1058,8 @@ void LlamaContext::RemoveLoraAdapters(const Napi::CallbackInfo &info) {
 
  // getLoadedLoraAdapters(): Promise<{ count, lora_adapters: [{ path: string,
  // scaled: number }] }>
- Napi::Value LlamaContext::GetLoadedLoraAdapters(const Napi::CallbackInfo &info) {
+ Napi::Value
+ LlamaContext::GetLoadedLoraAdapters(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  Napi::Array lora_adapters = Napi::Array::New(env, _lora.size());
  for (size_t i = 0; i < _lora.size(); i++) {
@@ -983,18 +1077,18 @@ Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
  if (_wip != nullptr) {
  _wip->SetStop();
  }
-
+
  if (_sess == nullptr) {
  auto promise = Napi::Promise::Deferred(env);
  promise.Resolve(env.Undefined());
  return promise.Promise();
  }
-
+
  // Clear the mtmd context reference in the session
  if (_mtmd_ctx != nullptr) {
  _sess->set_mtmd_ctx(nullptr);
  }
-
+
  auto *worker = new DisposeWorker(info, std::move(_sess));
  worker->Queue();
  return worker->Promise();
@@ -1022,7 +1116,8 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
 
  if (info.Length() < 1 || !info[0].IsObject()) {
- Napi::TypeError::New(env, "Object expected for mmproj path").ThrowAsJavaScriptException();
+ Napi::TypeError::New(env, "Object expected for mmproj path")
+ .ThrowAsJavaScriptException();
  }
 
  auto options = info[0].As<Napi::Object>();
@@ -1030,7 +1125,8 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {
  auto use_gpu = options.Get("use_gpu").ToBoolean().Value();
 
  if (mmproj_path.empty()) {
- Napi::TypeError::New(env, "mmproj path is required").ThrowAsJavaScriptException();
+ Napi::TypeError::New(env, "mmproj path is required")
+ .ThrowAsJavaScriptException();
  }
 
  console_log(env, "Initializing multimodal with mmproj path: " + mmproj_path);
@@ -1055,48 +1151,55 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {
  mtmd_params.n_threads = _sess->params().cpuparams.n_threads;
  mtmd_params.verbosity = (ggml_log_level)GGML_LOG_LEVEL_INFO;
 
- console_log(env, format_string("Initializing mtmd context with threads=%d, use_gpu=%d",
- mtmd_params.n_threads, mtmd_params.use_gpu ? 1 : 0));
+ console_log(env, format_string(
+ "Initializing mtmd context with threads=%d, use_gpu=%d",
+ mtmd_params.n_threads, mtmd_params.use_gpu ? 1 : 0));
 
  _mtmd_ctx = mtmd_init_from_file(mmproj_path.c_str(), model, mtmd_params);
  if (_mtmd_ctx == nullptr) {
- Napi::Error::New(env, "Failed to initialize multimodal context").ThrowAsJavaScriptException();
+ Napi::Error::New(env, "Failed to initialize multimodal context")
+ .ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }
 
  _has_multimodal = true;
-
+
  // Share the mtmd context with the session
  _sess->set_mtmd_ctx(_mtmd_ctx);
 
  // Check if the model uses M-RoPE or non-causal attention
  bool uses_mrope = mtmd_decode_use_mrope(_mtmd_ctx);
  bool uses_non_causal = mtmd_decode_use_non_causal(_mtmd_ctx);
- console_log(env, format_string("Model multimodal properties: uses_mrope=%d, uses_non_causal=%d",
- uses_mrope ? 1 : 0, uses_non_causal ? 1 : 0));
+ console_log(
+ env, format_string(
+ "Model multimodal properties: uses_mrope=%d, uses_non_causal=%d",
+ uses_mrope ? 1 : 0, uses_non_causal ? 1 : 0));
 
- console_log(env, "Multimodal context initialized successfully with mmproj: " + mmproj_path);
+ console_log(env, "Multimodal context initialized successfully with mmproj: " +
+ mmproj_path);
  return Napi::Boolean::New(env, true);
  }
 
  // isMultimodalEnabled(): boolean
  Napi::Value LlamaContext::IsMultimodalEnabled(const Napi::CallbackInfo &info) {
- return Napi::Boolean::New(info.Env(), _has_multimodal && _mtmd_ctx != nullptr);
+ return Napi::Boolean::New(info.Env(),
+ _has_multimodal && _mtmd_ctx != nullptr);
  }
 
  // getMultimodalSupport(): Promise<{ vision: boolean, audio: boolean }>
  Napi::Value LlamaContext::GetMultimodalSupport(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  auto result = Napi::Object::New(env);
-
+
  if (_has_multimodal && _mtmd_ctx != nullptr) {
- result.Set("vision", Napi::Boolean::New(env, mtmd_support_vision(_mtmd_ctx)));
+ result.Set("vision",
+ Napi::Boolean::New(env, mtmd_support_vision(_mtmd_ctx)));
  result.Set("audio", Napi::Boolean::New(env, mtmd_support_audio(_mtmd_ctx)));
  } else {
  result.Set("vision", Napi::Boolean::New(env, false));
  result.Set("audio", Napi::Boolean::New(env, false));
  }
-
+
  return result;
  }
 
@@ -1107,10 +1210,206 @@ void LlamaContext::ReleaseMultimodal(const Napi::CallbackInfo &info) {
  if (_sess != nullptr) {
  _sess->set_mtmd_ctx(nullptr);
  }
-
+
  // Free the mtmd context
  mtmd_free(_mtmd_ctx);
  _mtmd_ctx = nullptr;
  _has_multimodal = false;
  }
  }
+
+ tts_type LlamaContext::getTTSType(Napi::Env env, nlohmann::json speaker) {
+ if (speaker.is_object() && speaker.contains("version")) {
+ std::string version = speaker["version"].get<std::string>();
+ if (version == "0.2") {
+ return OUTETTS_V0_2;
+ } else if (version == "0.3") {
+ return OUTETTS_V0_3;
+ } else {
+ Napi::Error::New(env, format_string("Unsupported speaker version '%s'\n",
+ version.c_str()))
+ .ThrowAsJavaScriptException();
+ return UNKNOWN;
+ }
+ }
+ if (_tts_type != UNKNOWN) {
+ return _tts_type;
+ }
+ const char *chat_template =
+ llama_model_chat_template(_sess->model(), nullptr);
+ if (chat_template && std::string(chat_template) == "outetts-0.3") {
+ return OUTETTS_V0_3;
+ }
+ return OUTETTS_V0_2;
+ }
+
+ // initVocoder(path: string): boolean
+ Napi::Value LlamaContext::InitVocoder(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ if (info.Length() < 1 || !info[0].IsString()) {
+ Napi::TypeError::New(env, "String expected for vocoder path")
+ .ThrowAsJavaScriptException();
+ }
+ auto vocoder_path = info[0].ToString().Utf8Value();
+ if (vocoder_path.empty()) {
+ Napi::TypeError::New(env, "vocoder path is required")
+ .ThrowAsJavaScriptException();
+ }
+ if (_has_vocoder) {
+ Napi::Error::New(env, "Vocoder already initialized")
+ .ThrowAsJavaScriptException();
+ return Napi::Boolean::New(env, false);
+ }
+ _tts_type = getTTSType(env);
+ _vocoder.params = _sess->params();
+ _vocoder.params.warmup = false;
+ _vocoder.params.model.path = vocoder_path;
+ _vocoder.params.embedding = true;
+ _vocoder.params.ctx_shift = false;
+ _vocoder.params.n_ubatch = _vocoder.params.n_batch;
+ common_init_result result = common_init_from_params(_vocoder.params);
+ if (result.model == nullptr || result.context == nullptr) {
+ Napi::Error::New(env, "Failed to initialize vocoder")
+ .ThrowAsJavaScriptException();
+ return Napi::Boolean::New(env, false);
+ }
+ _vocoder.model = std::move(result.model);
+ _vocoder.context = std::move(result.context);
+ _has_vocoder = true;
+ return Napi::Boolean::New(env, true);
+ }
+
+ // releaseVocoder(): void
+ void LlamaContext::ReleaseVocoder(const Napi::CallbackInfo &info) {
+ if (_has_vocoder) {
+ _vocoder.model.reset();
+ _vocoder.context.reset();
+ _has_vocoder = false;
+ }
+ }
+
+ // isVocoderEnabled(): boolean
+ Napi::Value LlamaContext::IsVocoderEnabled(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ return Napi::Boolean::New(env, _has_vocoder);
+ }
+
+ // getFormattedAudioCompletion(speaker: string|null, text: string): string
+ Napi::Value
+ LlamaContext::GetFormattedAudioCompletion(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ if (info.Length() < 2 || !info[1].IsString()) {
+ Napi::TypeError::New(env, "text parameter is required for audio completion")
+ .ThrowAsJavaScriptException();
+ }
+ auto text = info[1].ToString().Utf8Value();
+ auto speaker_json = info[0].IsString() ? info[0].ToString().Utf8Value() : "";
+ nlohmann::json speaker =
+ speaker_json.empty() ? nullptr : nlohmann::json::parse(speaker_json);
+ const tts_type type = getTTSType(env, speaker);
+ std::string audio_text = DEFAULT_AUDIO_TEXT;
+ std::string audio_data = DEFAULT_AUDIO_DATA;
+ if (type == OUTETTS_V0_3) {
+ audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"),
+ "<|space|>");
+ audio_data =
+ std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
+ audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"),
+ "<|space|>");
+ }
+ if (!speaker_json.empty()) {
+ audio_text = audio_text_from_speaker(speaker, type);
+ audio_data = audio_data_from_speaker(speaker, type);
+ }
+ return Napi::String::New(env, "<|im_start|>\n" + audio_text +
+ process_text(text, type) +
+ "<|text_end|>\n" + audio_data + "\n");
+ }
+
+ // getAudioCompletionGuideTokens(text: string): Int32Array
+ Napi::Value
+ LlamaContext::GetAudioCompletionGuideTokens(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ if (info.Length() < 1 || !info[0].IsString()) {
+ Napi::TypeError::New(env,
+ "String expected for audio completion guide tokens")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+ auto text = info[0].ToString().Utf8Value();
+ const tts_type type = getTTSType(env);
+ auto clean_text = process_text(text, type);
+ const std::string &delimiter =
+ (type == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
+ const llama_vocab *vocab = llama_model_get_vocab(_sess->model());
+
+ std::vector<int32_t> result;
+ size_t start = 0;
+ size_t end = clean_text.find(delimiter);
+
+ // first token is always a newline, as it was not previously added
+ result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
+
+ while (end != std::string::npos) {
+ std::string current_word = clean_text.substr(start, end - start);
+ auto tmp = common_tokenize(vocab, current_word, false, true);
+ result.push_back(tmp[0]);
+ start = end + delimiter.length();
+ end = clean_text.find(delimiter, start);
+ }
+
+ // Add the last part
+ std::string current_word = clean_text.substr(start);
+ auto tmp = common_tokenize(vocab, current_word, false, true);
+ if (tmp.size() > 0) {
+ result.push_back(tmp[0]);
+ }
+ auto tokens = Napi::Int32Array::New(env, result.size());
+ memcpy(tokens.Data(), result.data(), result.size() * sizeof(int32_t));
+ return tokens;
+ }
+
+ // decodeAudioTokens(tokens: number[]|Int32Array): Float32Array
+ Napi::Value LlamaContext::DecodeAudioTokens(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ if (info.Length() < 1) {
+ Napi::TypeError::New(env, "Tokens parameter is required")
+ .ThrowAsJavaScriptException();
+ }
+ std::vector<int32_t> tokens;
+ if (info[0].IsTypedArray()) {
+ auto js_tokens = info[0].As<Napi::Int32Array>();
+ tokens.resize(js_tokens.ElementLength());
+ memcpy(tokens.data(), js_tokens.Data(),
+ js_tokens.ElementLength() * sizeof(int32_t));
+ } else if (info[0].IsArray()) {
+ auto js_tokens = info[0].As<Napi::Array>();
+ for (size_t i = 0; i < js_tokens.Length(); i++) {
+ tokens.push_back(js_tokens.Get(i).ToNumber().Int32Value());
+ }
+ } else {
+ Napi::TypeError::New(env, "Tokens must be an number array or a Int32Array")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+ tts_type type = getTTSType(env);
+ if (type == UNKNOWN) {
+ Napi::Error::New(env, "Unsupported audio tokens")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+ if (type == OUTETTS_V0_3 || type == OUTETTS_V0_2) {
+ tokens.erase(
+ std::remove_if(tokens.begin(), tokens.end(),
+ [](llama_token t) { return t < 151672 || t > 155772; }),
+ tokens.end());
+ for (auto &token : tokens) {
+ token -= 151672;
+ }
+ }
+ auto worker = new DecodeAudioTokenWorker(
+ info, _vocoder.model.get(), _vocoder.context.get(),
+ _sess->params().cpuparams.n_threads, tokens);
+ worker->Queue();
+ return worker->Promise();
+ }
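Taken together, the new methods cover an OuteTTS-style flow: initialize a vocoder model, format the audio prompt (optionally from a speaker JSON), derive guide tokens, run the completion, then decode the resulting audio tokens into PCM samples. A hedged end-to-end TypeScript sketch follows; the method signatures come from the comments in this file, while the context typing is assumed and the way audio tokens are collected from the completion result is left to the caller, since that field is not visible in this diff.

// End-to-end TTS sketch built only from the signatures in this file (assumed wrapper).
interface TTSContext {
  initVocoder(path: string): boolean;
  isVocoderEnabled(): boolean;
  getFormattedAudioCompletion(speaker: string | null, text: string): string;
  getAudioCompletionGuideTokens(text: string): Int32Array;
  completion(options: { prompt: string; guide_tokens?: number[] }): Promise<unknown>;
  decodeAudioTokens(tokens: number[] | Int32Array): Promise<Float32Array>;
  releaseVocoder(): void;
}

async function synthesize(
  ctx: TTSContext,
  vocoderPath: string,
  text: string,
  collectAudioTokens: (completionResult: unknown) => number[],
): Promise<Float32Array> {
  if (!ctx.isVocoderEnabled() && !ctx.initVocoder(vocoderPath)) {
    throw new Error('failed to initialize vocoder');
  }
  const prompt = ctx.getFormattedAudioCompletion(null, text);
  const guide = ctx.getAudioCompletionGuideTokens(text);
  const result = await ctx.completion({ prompt, guide_tokens: Array.from(guide) });
  // decodeAudioTokens filters/offsets OuteTTS code tokens natively and runs
  // the vocoder on a worker, resolving with PCM float samples.
  const pcm = await ctx.decodeAudioTokens(collectAudioTokens(result));
  ctx.releaseVocoder();
  return pcm;
}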