@fugood/llama.node 0.5.0 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,5 @@
- #include "ggml.h"
- #include "gguf.h"
- #include "llama-impl.h"
- #include "json.hpp"
- #include "json-schema-to-grammar.h"
  #include "LlamaContext.h"
+ #include "DecodeAudioTokenWorker.h"
  #include "DetokenizeWorker.h"
  #include "DisposeWorker.h"
  #include "EmbeddingWorker.h"
@@ -11,33 +7,42 @@
  #include "LoadSessionWorker.h"
  #include "SaveSessionWorker.h"
  #include "TokenizeWorker.h"
+ #include "ggml.h"
+ #include "gguf.h"
+ #include "json-schema-to-grammar.h"
+ #include "json.hpp"
+ #include "llama-impl.h"

+ #include <atomic>
  #include <mutex>
  #include <queue>
- #include <atomic>

  // Helper function for formatted strings (for console logs)
- template<typename ... Args>
- static std::string format_string(const std::string& format, Args ... args) {
- int size_s = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; // +1 for null terminator
- if (size_s <= 0) { return "Error formatting string"; }
- auto size = static_cast<size_t>(size_s);
- std::unique_ptr<char[]> buf(new char[size]);
- std::snprintf(buf.get(), size, format.c_str(), args ...);
- return std::string(buf.get(), buf.get() + size - 1); // -1 to exclude null terminator
+ template <typename... Args>
+ static std::string format_string(const std::string &format, Args... args) {
+ int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) +
+ 1; // +1 for null terminator
+ if (size_s <= 0) {
+ return "Error formatting string";
+ }
+ auto size = static_cast<size_t>(size_s);
+ std::unique_ptr<char[]> buf(new char[size]);
+ std::snprintf(buf.get(), size, format.c_str(), args...);
+ return std::string(buf.get(),
+ buf.get() + size - 1); // -1 to exclude null terminator
  }

  using json = nlohmann::ordered_json;

  // loadModelInfo(path: string): object
- Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo& info) {
+ Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  struct gguf_init_params params = {
- /*.no_alloc = */ false,
- /*.ctx = */ NULL,
+ /*.no_alloc = */ false,
+ /*.ctx = */ NULL,
  };
  std::string path = info[0].ToString().Utf8Value();
-
+
  // Convert Napi::Array to vector<string>
  std::vector<std::string> skip;
  if (info.Length() > 1 && info[1].IsArray()) {
@@ -47,7 +52,7 @@ Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo& info) {
  }
  }

- struct gguf_context * ctx = gguf_init_from_file(path.c_str(), params);
+ struct gguf_context *ctx = gguf_init_from_file(path.c_str(), params);

  Napi::Object metadata = Napi::Object::New(env);
  if (std::find(skip.begin(), skip.end(), "version") == skip.end()) {
@@ -57,7 +62,8 @@ Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo& info) {
  metadata.Set("alignment", Napi::Number::New(env, gguf_get_alignment(ctx)));
  }
  if (std::find(skip.begin(), skip.end(), "data_offset") == skip.end()) {
- metadata.Set("data_offset", Napi::Number::New(env, gguf_get_data_offset(ctx)));
+ metadata.Set("data_offset",
+ Napi::Number::New(env, gguf_get_data_offset(ctx)));
  }

  // kv
@@ -65,7 +71,7 @@ Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo& info) {
  const int n_kv = gguf_get_n_kv(ctx);

  for (int i = 0; i < n_kv; ++i) {
- const char * key = gguf_get_key(ctx, i);
+ const char *key = gguf_get_key(ctx, i);
  if (std::find(skip.begin(), skip.end(), key) != skip.end()) {
  continue;
  }
@@ -138,6 +144,24 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
  static_cast<napi_property_attributes>(napi_enumerable)),
  InstanceMethod<&LlamaContext::GetMultimodalSupport>(
  "getMultimodalSupport",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::InitVocoder>(
+ "initVocoder",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::ReleaseVocoder>(
+ "releaseVocoder",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::IsVocoderEnabled>(
+ "isVocoderEnabled",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::GetFormattedAudioCompletion>(
+ "getFormattedAudioCompletion",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::GetAudioCompletionGuideTokens>(
+ "getAudioCompletionGuideTokens",
+ static_cast<napi_property_attributes>(napi_enumerable)),
+ InstanceMethod<&LlamaContext::DecodeAudioTokens>(
+ "decodeAudioTokens",
  static_cast<napi_property_attributes>(napi_enumerable))});
  Napi::FunctionReference *constructor = new Napi::FunctionReference();
  *constructor = Napi::Persistent(func);
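The hunk above registers six new TTS-related instance methods on the native LlamaContext class (initVocoder, releaseVocoder, isVocoderEnabled, getFormattedAudioCompletion, getAudioCompletionGuideTokens, decodeAudioTokens), implemented further down in this diff. A minimal TypeScript sketch of the shape these bindings appear to take, inferred only from the method names, the signature comments in this diff, and the worker-based DecodeAudioTokens implementation; the package's exported wrapper API is not shown here and is an assumption:

// Hypothetical typing of the new native methods added in 0.6.0.
// Names and argument types come from this diff; the Promise return of
// decodeAudioTokens is inferred from it queueing a DecodeAudioTokenWorker.
interface LlamaContextTTSBindings {
  initVocoder(path: string): boolean;
  releaseVocoder(): void;
  isVocoderEnabled(): boolean;
  getFormattedAudioCompletion(speaker: string | null, text: string): string;
  getAudioCompletionGuideTokens(text: string): Int32Array;
  decodeAudioTokens(tokens: number[] | Int32Array): Promise<Float32Array>;
}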
@@ -148,19 +172,13 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
  }

  const std::vector<ggml_type> kv_cache_types = {
- GGML_TYPE_F32,
- GGML_TYPE_F16,
- GGML_TYPE_BF16,
- GGML_TYPE_Q8_0,
- GGML_TYPE_Q4_0,
- GGML_TYPE_Q4_1,
- GGML_TYPE_IQ4_NL,
- GGML_TYPE_Q5_0,
- GGML_TYPE_Q5_1,
+ GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
+ GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+ GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
  };

- static ggml_type kv_cache_type_from_str(const std::string & s) {
- for (const auto & type : kv_cache_types) {
+ static ggml_type kv_cache_type_from_str(const std::string &s) {
+ for (const auto &type : kv_cache_types) {
  if (ggml_type_name(type) == s) {
  return type;
  }
@@ -168,12 +186,17 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
  throw std::runtime_error("Unsupported cache type: " + s);
  }

- static int32_t pooling_type_from_str(const std::string & s) {
- if (s == "none") return LLAMA_POOLING_TYPE_NONE;
- if (s == "mean") return LLAMA_POOLING_TYPE_MEAN;
- if (s == "cls") return LLAMA_POOLING_TYPE_CLS;
- if (s == "last") return LLAMA_POOLING_TYPE_LAST;
- if (s == "rank") return LLAMA_POOLING_TYPE_RANK;
+ static int32_t pooling_type_from_str(const std::string &s) {
+ if (s == "none")
+ return LLAMA_POOLING_TYPE_NONE;
+ if (s == "mean")
+ return LLAMA_POOLING_TYPE_MEAN;
+ if (s == "cls")
+ return LLAMA_POOLING_TYPE_CLS;
+ if (s == "last")
+ return LLAMA_POOLING_TYPE_LAST;
+ if (s == "rank")
+ return LLAMA_POOLING_TYPE_RANK;
  return LLAMA_POOLING_TYPE_UNSPECIFIED;
  }

@@ -200,7 +223,8 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)

  params.chat_template = get_option<std::string>(options, "chat_template", "");

- std::string reasoning_format = get_option<std::string>(options, "reasoning_format", "none");
+ std::string reasoning_format =
+ get_option<std::string>(options, "reasoning_format", "none");
  if (reasoning_format == "deepseek") {
  params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
  } else {
@@ -216,16 +240,17 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  params.n_ubatch = params.n_batch;
  }
  params.embd_normalize = get_option<int32_t>(options, "embd_normalize", 2);
- params.pooling_type = (enum llama_pooling_type) pooling_type_from_str(
- get_option<std::string>(options, "pooling_type", "").c_str()
- );
+ params.pooling_type = (enum llama_pooling_type)pooling_type_from_str(
+ get_option<std::string>(options, "pooling_type", "").c_str());

  params.cpuparams.n_threads =
  get_option<int32_t>(options, "n_threads", cpu_get_num_math() / 2);
  params.n_gpu_layers = get_option<int32_t>(options, "n_gpu_layers", -1);
  params.flash_attn = get_option<bool>(options, "flash_attn", false);
- params.cache_type_k = kv_cache_type_from_str(get_option<std::string>(options, "cache_type_k", "f16").c_str());
- params.cache_type_v = kv_cache_type_from_str(get_option<std::string>(options, "cache_type_v", "f16").c_str());
+ params.cache_type_k = kv_cache_type_from_str(
+ get_option<std::string>(options, "cache_type_k", "f16").c_str());
+ params.cache_type_v = kv_cache_type_from_str(
+ get_option<std::string>(options, "cache_type_v", "f16").c_str());
  params.ctx_shift = get_option<bool>(options, "ctx_shift", true);

  params.use_mlock = get_option<bool>(options, "use_mlock", false);
@@ -296,8 +321,9 @@ Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) {
  return Napi::String::New(info.Env(), _info);
  }

- bool validateModelChatTemplate(const struct llama_model * model, const bool use_jinja, const char * name) {
- const char * tmpl = llama_model_chat_template(model, name);
+ bool validateModelChatTemplate(const struct llama_model *model,
+ const bool use_jinja, const char *name) {
+ const char *tmpl = llama_model_chat_template(model, name);
  if (tmpl == nullptr) {
  return false;
  }
@@ -323,68 +349,68 @@ extern "C" void cleanup_logging();
  void LlamaContext::ToggleNativeLog(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  bool enable = info[0].ToBoolean().Value();
-
+
  if (enable) {
  if (!info[1].IsFunction()) {
- Napi::TypeError::New(env, "Callback function required").ThrowAsJavaScriptException();
+ Napi::TypeError::New(env, "Callback function required")
+ .ThrowAsJavaScriptException();
  return;
  }
-
+
  // First clean up existing thread-safe function if any
  if (g_logging_enabled) {
  g_tsfn.Release();
  g_logging_enabled = false;
  }
-
+
  // Create thread-safe function that can be called from any thread
- g_tsfn = Napi::ThreadSafeFunction::New(
- env,
- info[1].As<Napi::Function>(),
- "LLAMA Logger",
- 0,
- 1,
- [](Napi::Env) {
- // Finalizer callback - nothing needed here
- }
- );
+ g_tsfn = Napi::ThreadSafeFunction::New(env, info[1].As<Napi::Function>(),
+ "LLAMA Logger", 0, 1, [](Napi::Env) {
+ // Finalizer callback - nothing
+ // needed here
+ });

  g_logging_enabled = true;
-
+
  // Set up log callback
- llama_log_set([](ggml_log_level level, const char* text, void* user_data) {
- // First call the default logger
- llama_log_callback_default(level, text, user_data);
-
- if (!g_logging_enabled) return;
-
- // Determine log level string
- std::string level_str = "";
- if (level == GGML_LOG_LEVEL_ERROR) {
- level_str = "error";
- } else if (level == GGML_LOG_LEVEL_INFO) {
- level_str = "info";
- } else if (level == GGML_LOG_LEVEL_WARN) {
- level_str = "warn";
- }
-
- // Create a heap-allocated copy of the data
- auto* data = new LogMessage{level_str, text};
-
- // Queue callback to be executed on the JavaScript thread
- auto status = g_tsfn.BlockingCall(data, [](Napi::Env env, Napi::Function jsCallback, LogMessage* data) {
- // This code runs on the JavaScript thread
- jsCallback.Call({
- Napi::String::New(env, data->level),
- Napi::String::New(env, data->text)
- });
- delete data;
- });
-
- // If the call failed (e.g., runtime is shutting down), clean up the data
- if (status != napi_ok) {
- delete data;
- }
- }, nullptr);
+ llama_log_set(
+ [](ggml_log_level level, const char *text, void *user_data) {
+ // First call the default logger
+ llama_log_callback_default(level, text, user_data);
+
+ if (!g_logging_enabled)
+ return;
+
+ // Determine log level string
+ std::string level_str = "";
+ if (level == GGML_LOG_LEVEL_ERROR) {
+ level_str = "error";
+ } else if (level == GGML_LOG_LEVEL_INFO) {
+ level_str = "info";
+ } else if (level == GGML_LOG_LEVEL_WARN) {
+ level_str = "warn";
+ }
+
+ // Create a heap-allocated copy of the data
+ auto *data = new LogMessage{level_str, text};
+
+ // Queue callback to be executed on the JavaScript thread
+ auto status = g_tsfn.BlockingCall(
+ data,
+ [](Napi::Env env, Napi::Function jsCallback, LogMessage *data) {
+ // This code runs on the JavaScript thread
+ jsCallback.Call({Napi::String::New(env, data->level),
+ Napi::String::New(env, data->text)});
+ delete data;
+ });
+
+ // If the call failed (e.g., runtime is shutting down), clean up the
+ // data
+ if (status != napi_ok) {
+ delete data;
+ }
+ },
+ nullptr);
  } else {
  // Disable logging
  if (g_logging_enabled) {
@@ -422,22 +448,47 @@ Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {
  Napi::Object minja = Napi::Object::New(info.Env());
  minja.Set("default", validateModelChatTemplate(model, true, ""));
  Napi::Object defaultCaps = Napi::Object::New(info.Env());
- defaultCaps.Set("tools", _templates.get()->template_default->original_caps().supports_tools);
- defaultCaps.Set("toolCalls", _templates.get()->template_default->original_caps().supports_tool_calls);
- defaultCaps.Set("toolResponses", _templates.get()->template_default->original_caps().supports_tool_responses);
- defaultCaps.Set("systemRole", _templates.get()->template_default->original_caps().supports_system_role);
- defaultCaps.Set("parallelToolCalls", _templates.get()->template_default->original_caps().supports_parallel_tool_calls);
- defaultCaps.Set("toolCallId", _templates.get()->template_default->original_caps().supports_tool_call_id);
+ defaultCaps.Set(
+ "tools",
+ _templates.get()->template_default->original_caps().supports_tools);
+ defaultCaps.Set(
+ "toolCalls",
+ _templates.get()->template_default->original_caps().supports_tool_calls);
+ defaultCaps.Set("toolResponses", _templates.get()
+ ->template_default->original_caps()
+ .supports_tool_responses);
+ defaultCaps.Set(
+ "systemRole",
+ _templates.get()->template_default->original_caps().supports_system_role);
+ defaultCaps.Set("parallelToolCalls", _templates.get()
+ ->template_default->original_caps()
+ .supports_parallel_tool_calls);
+ defaultCaps.Set("toolCallId", _templates.get()
+ ->template_default->original_caps()
+ .supports_tool_call_id);
  minja.Set("defaultCaps", defaultCaps);
  minja.Set("toolUse", validateModelChatTemplate(model, true, "tool_use"));
  if (_templates.get()->template_tool_use) {
  Napi::Object toolUseCaps = Napi::Object::New(info.Env());
- toolUseCaps.Set("tools", _templates.get()->template_tool_use->original_caps().supports_tools);
- toolUseCaps.Set("toolCalls", _templates.get()->template_tool_use->original_caps().supports_tool_calls);
- toolUseCaps.Set("toolResponses", _templates.get()->template_tool_use->original_caps().supports_tool_responses);
- toolUseCaps.Set("systemRole", _templates.get()->template_tool_use->original_caps().supports_system_role);
- toolUseCaps.Set("parallelToolCalls", _templates.get()->template_tool_use->original_caps().supports_parallel_tool_calls);
- toolUseCaps.Set("toolCallId", _templates.get()->template_tool_use->original_caps().supports_tool_call_id);
+ toolUseCaps.Set(
+ "tools",
+ _templates.get()->template_tool_use->original_caps().supports_tools);
+ toolUseCaps.Set("toolCalls", _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_tool_calls);
+ toolUseCaps.Set("toolResponses", _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_tool_responses);
+ toolUseCaps.Set("systemRole", _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_system_role);
+ toolUseCaps.Set("parallelToolCalls",
+ _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_parallel_tool_calls);
+ toolUseCaps.Set("toolCallId", _templates.get()
+ ->template_tool_use->original_caps()
+ .supports_tool_call_id);
  minja.Set("toolUseCaps", toolUseCaps);
  }
  chatTemplates.Set("minja", minja);
@@ -446,20 +497,17 @@ Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {
  details.Set("metadata", metadata);

  // Deprecated: use chatTemplates.llamaChat instead
- details.Set("isChatTemplateSupported", validateModelChatTemplate(_sess->model(), false, ""));
+ details.Set("isChatTemplateSupported",
+ validateModelChatTemplate(_sess->model(), false, ""));
  return details;
  }

  common_chat_params getFormattedChatWithJinja(
- const std::shared_ptr<LlamaSession> &sess,
- const common_chat_templates_ptr &templates,
- const std::string &messages,
- const std::string &chat_template,
- const std::string &json_schema,
- const std::string &tools,
- const bool &parallel_tool_calls,
- const std::string &tool_choice
- ) {
+ const std::shared_ptr<LlamaSession> &sess,
+ const common_chat_templates_ptr &templates, const std::string &messages,
+ const std::string &chat_template, const std::string &json_schema,
+ const std::string &tools, const bool &parallel_tool_calls,
+ const std::string &tool_choice) {
  common_chat_templates_inputs inputs;
  inputs.messages = common_chat_msgs_parse_oaicompat(json::parse(messages));
  auto useTools = !tools.empty();
@@ -473,23 +521,22 @@ common_chat_params getFormattedChatWithJinja(
  if (!json_schema.empty()) {
  inputs.json_schema = json::parse(json_schema);
  }
- inputs.extract_reasoning = sess->params().reasoning_format != COMMON_REASONING_FORMAT_NONE;
+ inputs.extract_reasoning =
+ sess->params().reasoning_format != COMMON_REASONING_FORMAT_NONE;

  // If chat_template is provided, create new one and use it (probably slow)
  if (!chat_template.empty()) {
- auto tmps = common_chat_templates_init(sess->model(), chat_template);
- return common_chat_templates_apply(tmps.get(), inputs);
+ auto tmps = common_chat_templates_init(sess->model(), chat_template);
+ return common_chat_templates_apply(tmps.get(), inputs);
  } else {
- return common_chat_templates_apply(templates.get(), inputs);
+ return common_chat_templates_apply(templates.get(), inputs);
  }
  }

- std::string getFormattedChat(
- const struct llama_model * model,
- const common_chat_templates_ptr &templates,
- const std::string &messages,
- const std::string &chat_template
- ) {
+ std::string getFormattedChat(const struct llama_model *model,
+ const common_chat_templates_ptr &templates,
+ const std::string &messages,
+ const std::string &chat_template) {
  common_chat_templates_inputs inputs;
  inputs.messages = common_chat_msgs_parse_oaicompat(json::parse(messages));
  inputs.use_jinja = false;
@@ -506,7 +553,8 @@ std::string getFormattedChat(
  // getFormattedChat(
  // messages: [{ role: string, content: string }],
  // chat_template: string,
- // params: { jinja: boolean, json_schema: string, tools: string, parallel_tool_calls: boolean, tool_choice: string }
+ // params: { jinja: boolean, json_schema: string, tools: string,
+ // parallel_tool_calls: boolean, tool_choice: string }
  // ): object | string
  Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
@@ -517,32 +565,42 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  auto chat_template = info[1].IsString() ? info[1].ToString().Utf8Value() : "";

  auto has_params = info.Length() >= 2;
- auto params = has_params ? info[2].As<Napi::Object>() : Napi::Object::New(env);
+ auto params =
+ has_params ? info[2].As<Napi::Object>() : Napi::Object::New(env);

  if (get_option<bool>(params, "jinja", false)) {
  std::string json_schema_str = "";
  if (!is_nil(params.Get("response_format"))) {
  auto response_format = params.Get("response_format").As<Napi::Object>();
- auto response_format_type = get_option<std::string>(response_format, "type", "text");
- if (response_format_type == "json_schema" && response_format.Has("json_schema")) {
- auto json_schema = response_format.Get("json_schema").As<Napi::Object>();
- json_schema_str = json_schema.Has("schema") ?
- json_stringify(json_schema.Get("schema").As<Napi::Object>()) :
- "{}";
+ auto response_format_type =
+ get_option<std::string>(response_format, "type", "text");
+ if (response_format_type == "json_schema" &&
+ response_format.Has("json_schema")) {
+ auto json_schema =
+ response_format.Get("json_schema").As<Napi::Object>();
+ json_schema_str =
+ json_schema.Has("schema")
+ ? json_stringify(json_schema.Get("schema").As<Napi::Object>())
+ : "{}";
  } else if (response_format_type == "json_object") {
- json_schema_str = response_format.Has("schema") ?
- json_stringify(response_format.Get("schema").As<Napi::Object>()) :
- "{}";
+ json_schema_str =
+ response_format.Has("schema")
+ ? json_stringify(
+ response_format.Get("schema").As<Napi::Object>())
+ : "{}";
  }
  }
- auto tools_str = params.Has("tools") ?
- json_stringify(params.Get("tools").As<Napi::Array>()) :
- "";
- auto parallel_tool_calls = get_option<bool>(params, "parallel_tool_calls", false);
+ auto tools_str = params.Has("tools")
+ ? json_stringify(params.Get("tools").As<Napi::Array>())
+ : "";
+ auto parallel_tool_calls =
+ get_option<bool>(params, "parallel_tool_calls", false);
  auto tool_choice = get_option<std::string>(params, "tool_choice", "");

- auto chatParams = getFormattedChatWithJinja(_sess, _templates, messages, chat_template, json_schema_str, tools_str, parallel_tool_calls, tool_choice);
-
+ auto chatParams = getFormattedChatWithJinja(
+ _sess, _templates, messages, chat_template, json_schema_str, tools_str,
+ parallel_tool_calls, tool_choice);
+
  Napi::Object result = Napi::Object::New(env);
  result.Set("prompt", chatParams.prompt);
  // chat_format: int
@@ -554,30 +612,33 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  // grammar_triggers: [{ value: string, token: number }]
  Napi::Array grammar_triggers = Napi::Array::New(env);
  for (size_t i = 0; i < chatParams.grammar_triggers.size(); i++) {
- const auto & trigger = chatParams.grammar_triggers[i];
- Napi::Object triggerObj = Napi::Object::New(env);
- triggerObj.Set("type", Napi::Number::New(env, trigger.type));
- triggerObj.Set("value", Napi::String::New(env, trigger.value));
- triggerObj.Set("token", Napi::Number::New(env, trigger.token));
- grammar_triggers.Set(i, triggerObj);
+ const auto &trigger = chatParams.grammar_triggers[i];
+ Napi::Object triggerObj = Napi::Object::New(env);
+ triggerObj.Set("type", Napi::Number::New(env, trigger.type));
+ triggerObj.Set("value", Napi::String::New(env, trigger.value));
+ triggerObj.Set("token", Napi::Number::New(env, trigger.token));
+ grammar_triggers.Set(i, triggerObj);
  }
  result.Set("grammar_triggers", grammar_triggers);
  // preserved_tokens: string[]
  Napi::Array preserved_tokens = Napi::Array::New(env);
  for (size_t i = 0; i < chatParams.preserved_tokens.size(); i++) {
- preserved_tokens.Set(i, Napi::String::New(env, chatParams.preserved_tokens[i].c_str()));
+ preserved_tokens.Set(
+ i, Napi::String::New(env, chatParams.preserved_tokens[i].c_str()));
  }
  result.Set("preserved_tokens", preserved_tokens);
  // additional_stops: string[]
  Napi::Array additional_stops = Napi::Array::New(env);
  for (size_t i = 0; i < chatParams.additional_stops.size(); i++) {
- additional_stops.Set(i, Napi::String::New(env, chatParams.additional_stops[i].c_str()));
+ additional_stops.Set(
+ i, Napi::String::New(env, chatParams.additional_stops[i].c_str()));
  }
  result.Set("additional_stops", additional_stops);

  return result;
  } else {
- auto formatted = getFormattedChat(_sess->model(), _templates, messages, chat_template);
+ auto formatted =
+ getFormattedChat(_sess->model(), _templates, messages, chat_template);
  return Napi::String::New(env, formatted);
  }
  }
@@ -625,7 +686,9 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {

  // Check if multimodal is enabled when media_paths are provided
  if (!media_paths.empty() && !(_has_multimodal && _mtmd_ctx != nullptr)) {
- Napi::Error::New(env, "Multimodal support must be enabled via initMultimodal to use media_paths").ThrowAsJavaScriptException();
+ Napi::Error::New(env, "Multimodal support must be enabled via "
+ "initMultimodal to use media_paths")
+ .ThrowAsJavaScriptException();
  return env.Undefined();
  }

@@ -641,16 +704,20 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  std::string json_schema_str = "";
  if (options.Has("response_format")) {
  auto response_format = options.Get("response_format").As<Napi::Object>();
- auto response_format_type = get_option<std::string>(response_format, "type", "text");
- if (response_format_type == "json_schema" && response_format.Has("json_schema")) {
+ auto response_format_type =
+ get_option<std::string>(response_format, "type", "text");
+ if (response_format_type == "json_schema" &&
+ response_format.Has("json_schema")) {
  auto json_schema = response_format.Get("json_schema").As<Napi::Object>();
- json_schema_str = json_schema.Has("schema") ?
- json_stringify(json_schema.Get("schema").As<Napi::Object>()) :
- "{}";
+ json_schema_str =
+ json_schema.Has("schema")
+ ? json_stringify(json_schema.Get("schema").As<Napi::Object>())
+ : "{}";
  } else if (response_format_type == "json_object") {
- json_schema_str = response_format.Has("schema") ?
- json_stringify(response_format.Get("schema").As<Napi::Object>()) :
- "{}";
+ json_schema_str =
+ response_format.Has("schema")
+ ? json_stringify(response_format.Get("schema").As<Napi::Object>())
+ : "{}";
  }
  }

@@ -659,7 +726,9 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  auto preserved_tokens = options.Get("preserved_tokens").As<Napi::Array>();
  for (size_t i = 0; i < preserved_tokens.Length(); i++) {
  auto token = preserved_tokens.Get(i).ToString().Utf8Value();
- auto ids = common_tokenize(_sess->context(), token, /* add_special= */ false, /* parse_special= */ true);
+ auto ids =
+ common_tokenize(_sess->context(), token, /* add_special= */ false,
+ /* parse_special= */ true);
  if (ids.size() == 1) {
  params.sampling.preserved_tokens.insert(ids[0]);
  }
@@ -672,15 +741,22 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  for (size_t i = 0; i < grammar_triggers.Length(); i++) {
  auto trigger_obj = grammar_triggers.Get(i).As<Napi::Object>();

- auto type = static_cast<common_grammar_trigger_type>(trigger_obj.Get("type").ToNumber().Int32Value());
+ auto type = static_cast<common_grammar_trigger_type>(
+ trigger_obj.Get("type").ToNumber().Int32Value());
  auto word = trigger_obj.Get("value").ToString().Utf8Value();

  if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
- auto ids = common_tokenize(_sess->context(), word, /* add_special= */ false, /* parse_special= */ true);
+ auto ids =
+ common_tokenize(_sess->context(), word, /* add_special= */ false,
+ /* parse_special= */ true);
  if (ids.size() == 1) {
  auto token = ids[0];
- if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
- throw std::runtime_error("Grammar trigger word should be marked as preserved token");
+ if (std::find(params.sampling.preserved_tokens.begin(),
+ params.sampling.preserved_tokens.end(),
+ (llama_token)token) ==
+ params.sampling.preserved_tokens.end()) {
+ throw std::runtime_error(
+ "Grammar trigger word should be marked as preserved token");
  }
  common_grammar_trigger trigger;
  trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
@@ -688,14 +764,16 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  trigger.token = token;
  params.sampling.grammar_triggers.push_back(std::move(trigger));
  } else {
- params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
+ params.sampling.grammar_triggers.push_back(
+ {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
  }
  } else {
  common_grammar_trigger trigger;
  trigger.type = type;
  trigger.value = word;
  if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
- auto token = (llama_token) trigger_obj.Get("token").ToNumber().Int32Value();
+ auto token =
+ (llama_token)trigger_obj.Get("token").ToNumber().Int32Value();
  trigger.token = token;
  }
  params.sampling.grammar_triggers.push_back(std::move(trigger));
@@ -705,7 +783,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {

  // Handle grammar_lazy from options
  if (options.Has("grammar_lazy")) {
- params.sampling.grammar_lazy = options.Get("grammar_lazy").ToBoolean().Value();
+ params.sampling.grammar_lazy =
+ options.Get("grammar_lazy").ToBoolean().Value();
  }

  if (options.Has("messages") && options.Get("messages").IsArray()) {
@@ -713,29 +792,27 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  auto chat_template = get_option<std::string>(options, "chat_template", "");
  auto jinja = get_option<bool>(options, "jinja", false);
  if (jinja) {
- auto tools_str = options.Has("tools") ?
- json_stringify(options.Get("tools").As<Napi::Array>()) :
- "";
- auto parallel_tool_calls = get_option<bool>(options, "parallel_tool_calls", false);
- auto tool_choice = get_option<std::string>(options, "tool_choice", "none");
+ auto tools_str =
+ options.Has("tools")
+ ? json_stringify(options.Get("tools").As<Napi::Array>())
+ : "";
+ auto parallel_tool_calls =
+ get_option<bool>(options, "parallel_tool_calls", false);
+ auto tool_choice =
+ get_option<std::string>(options, "tool_choice", "none");

  auto chatParams = getFormattedChatWithJinja(
- _sess,
- _templates,
- json_stringify(messages),
- chat_template,
- json_schema_str,
- tools_str,
- parallel_tool_calls,
- tool_choice
- );
-
+ _sess, _templates, json_stringify(messages), chat_template,
+ json_schema_str, tools_str, parallel_tool_calls, tool_choice);
+
  params.prompt = chatParams.prompt;

  chat_format = chatParams.format;

- for (const auto & token : chatParams.preserved_tokens) {
- auto ids = common_tokenize(_sess->context(), token, /* add_special= */ false, /* parse_special= */ true);
+ for (const auto &token : chatParams.preserved_tokens) {
+ auto ids =
+ common_tokenize(_sess->context(), token, /* add_special= */ false,
+ /* parse_special= */ true);
  if (ids.size() == 1) {
  params.sampling.preserved_tokens.insert(ids[0]);
  }
@@ -745,22 +822,18 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  // grammar param always wins jinja template & json_schema
  params.sampling.grammar = chatParams.grammar;
  params.sampling.grammar_lazy = chatParams.grammar_lazy;
- for (const auto & trigger : chatParams.grammar_triggers) {
+ for (const auto &trigger : chatParams.grammar_triggers) {
  params.sampling.grammar_triggers.push_back(trigger);
  }
  has_grammar_set = true;
  }
-
- for (const auto & stop : chatParams.additional_stops) {
+
+ for (const auto &stop : chatParams.additional_stops) {
  stop_words.push_back(stop);
  }
  } else {
  auto formatted = getFormattedChat(
- _sess->model(),
- _templates,
- json_stringify(messages),
- chat_template
- );
+ _sess->model(), _templates, json_stringify(messages), chat_template);
  params.prompt = formatted;
  }
  } else {
@@ -772,7 +845,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  }

  if (!has_grammar_set && !json_schema_str.empty()) {
- params.sampling.grammar = json_schema_to_grammar(json::parse(json_schema_str));
+ params.sampling.grammar =
+ json_schema_to_grammar(json::parse(json_schema_str));
  }

  params.n_predict = get_option<int32_t>(options, "n_predict", -1);
@@ -794,16 +868,32 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  params.sampling.penalty_present =
  get_option<float>(options, "penalty_present", 0.00f);
  params.sampling.typ_p = get_option<float>(options, "typical_p", 1.00f);
- params.sampling.xtc_threshold = get_option<float>(options, "xtc_threshold", 0.00f);
- params.sampling.xtc_probability = get_option<float>(options, "xtc_probability", 0.10f);
- params.sampling.dry_multiplier = get_option<float>(options, "dry_multiplier", 1.75f);
+ params.sampling.xtc_threshold =
+ get_option<float>(options, "xtc_threshold", 0.00f);
+ params.sampling.xtc_probability =
+ get_option<float>(options, "xtc_probability", 0.10f);
+ params.sampling.dry_multiplier =
+ get_option<float>(options, "dry_multiplier", 1.75f);
  params.sampling.dry_base = get_option<float>(options, "dry_base", 2);
- params.sampling.dry_allowed_length = get_option<float>(options, "dry_allowed_length", -1);
- params.sampling.dry_penalty_last_n = get_option<float>(options, "dry_penalty_last_n", 0);
- params.sampling.top_n_sigma = get_option<float>(options, "top_n_sigma", -1.0f);
+ params.sampling.dry_allowed_length =
+ get_option<float>(options, "dry_allowed_length", -1);
+ params.sampling.dry_penalty_last_n =
+ get_option<float>(options, "dry_penalty_last_n", 0);
+ params.sampling.top_n_sigma =
+ get_option<float>(options, "top_n_sigma", -1.0f);
  params.sampling.ignore_eos = get_option<bool>(options, "ignore_eos", false);
  params.n_keep = get_option<int32_t>(options, "n_keep", 0);
- params.sampling.seed = get_option<int32_t>(options, "seed", LLAMA_DEFAULT_SEED);
+ params.sampling.seed =
+ get_option<int32_t>(options, "seed", LLAMA_DEFAULT_SEED);
+
+ // guide_tokens
+ std::vector<llama_token> guide_tokens;
+ if (options.Has("guide_tokens")) {
+ auto guide_tokens_array = options.Get("guide_tokens").As<Napi::Array>();
+ for (size_t i = 0; i < guide_tokens_array.Length(); i++) {
+ guide_tokens.push_back(guide_tokens_array.Get(i).ToNumber().Int32Value());
+ }
+ }

  Napi::Function callback;
  if (info.Length() >= 2) {
@@ -811,7 +901,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  }

  auto *worker =
- new LlamaCompletionWorker(info, _sess, callback, params, stop_words, chat_format, media_paths);
+ new LlamaCompletionWorker(info, _sess, callback, params, stop_words,
+ chat_format, media_paths, guide_tokens);
  worker->Queue();
  _wip = worker;
  worker->OnComplete([this]() { _wip = nullptr; });
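The two hunks above add a new guide_tokens completion option: an array of token ids is read from the options object and forwarded to LlamaCompletionWorker. A hedged TypeScript sketch of how it could be combined with getAudioCompletionGuideTokens; only the option name guide_tokens and the native method names come from this diff, while the completion() call shape is an assumption for illustration:

// Assumed wrapper call shape; "guide_tokens" is the option added above.
const guideTokens = context.getAudioCompletionGuideTokens(text);
await context.completion({
  prompt,                                 // formatted audio prompt
  guide_tokens: Array.from(guideTokens),  // steer sampling along the input text
  n_predict: 4096,
});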
@@ -965,7 +1056,8 @@ void LlamaContext::RemoveLoraAdapters(const Napi::CallbackInfo &info) {

  // getLoadedLoraAdapters(): Promise<{ count, lora_adapters: [{ path: string,
  // scaled: number }] }>
- Napi::Value LlamaContext::GetLoadedLoraAdapters(const Napi::CallbackInfo &info) {
+ Napi::Value
+ LlamaContext::GetLoadedLoraAdapters(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  Napi::Array lora_adapters = Napi::Array::New(env, _lora.size());
  for (size_t i = 0; i < _lora.size(); i++) {
@@ -983,18 +1075,18 @@ Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
  if (_wip != nullptr) {
  _wip->SetStop();
  }
-
+
  if (_sess == nullptr) {
  auto promise = Napi::Promise::Deferred(env);
  promise.Resolve(env.Undefined());
  return promise.Promise();
  }
-
+
  // Clear the mtmd context reference in the session
  if (_mtmd_ctx != nullptr) {
  _sess->set_mtmd_ctx(nullptr);
  }
-
+
  auto *worker = new DisposeWorker(info, std::move(_sess));
  worker->Queue();
  return worker->Promise();
@@ -1022,7 +1114,8 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  if (info.Length() < 1 || !info[0].IsObject()) {
- Napi::TypeError::New(env, "Object expected for mmproj path").ThrowAsJavaScriptException();
+ Napi::TypeError::New(env, "Object expected for mmproj path")
+ .ThrowAsJavaScriptException();
  }

  auto options = info[0].As<Napi::Object>();
@@ -1030,7 +1123,8 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {
  auto use_gpu = options.Get("use_gpu").ToBoolean().Value();

  if (mmproj_path.empty()) {
- Napi::TypeError::New(env, "mmproj path is required").ThrowAsJavaScriptException();
+ Napi::TypeError::New(env, "mmproj path is required")
+ .ThrowAsJavaScriptException();
  }

  console_log(env, "Initializing multimodal with mmproj path: " + mmproj_path);
@@ -1055,48 +1149,55 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {
  mtmd_params.n_threads = _sess->params().cpuparams.n_threads;
  mtmd_params.verbosity = (ggml_log_level)GGML_LOG_LEVEL_INFO;

- console_log(env, format_string("Initializing mtmd context with threads=%d, use_gpu=%d",
- mtmd_params.n_threads, mtmd_params.use_gpu ? 1 : 0));
+ console_log(env, format_string(
+ "Initializing mtmd context with threads=%d, use_gpu=%d",
+ mtmd_params.n_threads, mtmd_params.use_gpu ? 1 : 0));

  _mtmd_ctx = mtmd_init_from_file(mmproj_path.c_str(), model, mtmd_params);
  if (_mtmd_ctx == nullptr) {
- Napi::Error::New(env, "Failed to initialize multimodal context").ThrowAsJavaScriptException();
+ Napi::Error::New(env, "Failed to initialize multimodal context")
+ .ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }

  _has_multimodal = true;
-
+
  // Share the mtmd context with the session
  _sess->set_mtmd_ctx(_mtmd_ctx);

  // Check if the model uses M-RoPE or non-causal attention
  bool uses_mrope = mtmd_decode_use_mrope(_mtmd_ctx);
  bool uses_non_causal = mtmd_decode_use_non_causal(_mtmd_ctx);
- console_log(env, format_string("Model multimodal properties: uses_mrope=%d, uses_non_causal=%d",
- uses_mrope ? 1 : 0, uses_non_causal ? 1 : 0));
+ console_log(
+ env, format_string(
+ "Model multimodal properties: uses_mrope=%d, uses_non_causal=%d",
+ uses_mrope ? 1 : 0, uses_non_causal ? 1 : 0));

- console_log(env, "Multimodal context initialized successfully with mmproj: " + mmproj_path);
+ console_log(env, "Multimodal context initialized successfully with mmproj: " +
+ mmproj_path);
  return Napi::Boolean::New(env, true);
  }

  // isMultimodalEnabled(): boolean
  Napi::Value LlamaContext::IsMultimodalEnabled(const Napi::CallbackInfo &info) {
- return Napi::Boolean::New(info.Env(), _has_multimodal && _mtmd_ctx != nullptr);
+ return Napi::Boolean::New(info.Env(),
+ _has_multimodal && _mtmd_ctx != nullptr);
  }

  // getMultimodalSupport(): Promise<{ vision: boolean, audio: boolean }>
  Napi::Value LlamaContext::GetMultimodalSupport(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  auto result = Napi::Object::New(env);
-
+
  if (_has_multimodal && _mtmd_ctx != nullptr) {
- result.Set("vision", Napi::Boolean::New(env, mtmd_support_vision(_mtmd_ctx)));
+ result.Set("vision",
+ Napi::Boolean::New(env, mtmd_support_vision(_mtmd_ctx)));
  result.Set("audio", Napi::Boolean::New(env, mtmd_support_audio(_mtmd_ctx)));
  } else {
  result.Set("vision", Napi::Boolean::New(env, false));
  result.Set("audio", Napi::Boolean::New(env, false));
  }
-
+
  return result;
  }

@@ -1107,10 +1208,206 @@ void LlamaContext::ReleaseMultimodal(const Napi::CallbackInfo &info) {
  if (_sess != nullptr) {
  _sess->set_mtmd_ctx(nullptr);
  }
-
+
  // Free the mtmd context
  mtmd_free(_mtmd_ctx);
  _mtmd_ctx = nullptr;
  _has_multimodal = false;
  }
  }
+
+ tts_type LlamaContext::getTTSType(Napi::Env env, nlohmann::json speaker) {
+ if (speaker.is_object() && speaker.contains("version")) {
+ std::string version = speaker["version"].get<std::string>();
+ if (version == "0.2") {
+ return OUTETTS_V0_2;
+ } else if (version == "0.3") {
+ return OUTETTS_V0_3;
+ } else {
+ Napi::Error::New(env, format_string("Unsupported speaker version '%s'\n",
+ version.c_str()))
+ .ThrowAsJavaScriptException();
+ return UNKNOWN;
+ }
+ }
+ if (_tts_type != UNKNOWN) {
+ return _tts_type;
+ }
+ const char *chat_template =
+ llama_model_chat_template(_sess->model(), nullptr);
+ if (chat_template && std::string(chat_template) == "outetts-0.3") {
+ return OUTETTS_V0_3;
+ }
+ return OUTETTS_V0_2;
+ }
+
+ // initVocoder(path: string): boolean
+ Napi::Value LlamaContext::InitVocoder(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ if (info.Length() < 1 || !info[0].IsString()) {
+ Napi::TypeError::New(env, "String expected for vocoder path")
+ .ThrowAsJavaScriptException();
+ }
+ auto vocoder_path = info[0].ToString().Utf8Value();
+ if (vocoder_path.empty()) {
+ Napi::TypeError::New(env, "vocoder path is required")
+ .ThrowAsJavaScriptException();
+ }
+ if (_has_vocoder) {
+ Napi::Error::New(env, "Vocoder already initialized")
+ .ThrowAsJavaScriptException();
+ return Napi::Boolean::New(env, false);
+ }
+ _tts_type = getTTSType(env);
+ _vocoder.params = _sess->params();
+ _vocoder.params.warmup = false;
+ _vocoder.params.model.path = vocoder_path;
+ _vocoder.params.embedding = true;
+ _vocoder.params.ctx_shift = false;
+ _vocoder.params.n_ubatch = _vocoder.params.n_batch;
+ common_init_result result = common_init_from_params(_vocoder.params);
+ if (result.model == nullptr || result.context == nullptr) {
+ Napi::Error::New(env, "Failed to initialize vocoder")
+ .ThrowAsJavaScriptException();
+ return Napi::Boolean::New(env, false);
+ }
+ _vocoder.model = std::move(result.model);
+ _vocoder.context = std::move(result.context);
+ _has_vocoder = true;
+ return Napi::Boolean::New(env, true);
+ }
+
+ // releaseVocoder(): void
+ void LlamaContext::ReleaseVocoder(const Napi::CallbackInfo &info) {
+ if (_has_vocoder) {
+ _vocoder.model.reset();
+ _vocoder.context.reset();
+ _has_vocoder = false;
+ }
+ }
+
+ // isVocoderEnabled(): boolean
+ Napi::Value LlamaContext::IsVocoderEnabled(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ return Napi::Boolean::New(env, _has_vocoder);
+ }
+
+ // getFormattedAudioCompletion(speaker: string|null, text: string): string
+ Napi::Value
+ LlamaContext::GetFormattedAudioCompletion(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ if (info.Length() < 2 || !info[1].IsString()) {
+ Napi::TypeError::New(env, "text parameter is required for audio completion")
+ .ThrowAsJavaScriptException();
+ }
+ auto text = info[1].ToString().Utf8Value();
+ auto speaker_json = info[0].IsString() ? info[0].ToString().Utf8Value() : "";
+ nlohmann::json speaker =
+ speaker_json.empty() ? nullptr : nlohmann::json::parse(speaker_json);
+ const tts_type type = getTTSType(env, speaker);
+ std::string audio_text = DEFAULT_AUDIO_TEXT;
+ std::string audio_data = DEFAULT_AUDIO_DATA;
+ if (type == OUTETTS_V0_3) {
+ audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"),
+ "<|space|>");
+ audio_data =
+ std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
+ audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"),
+ "<|space|>");
+ }
+ if (!speaker_json.empty()) {
+ audio_text = audio_text_from_speaker(speaker, type);
+ audio_data = audio_data_from_speaker(speaker, type);
+ }
+ return Napi::String::New(env, "<|im_start|>\n" + audio_text +
+ process_text(text, type) +
+ "<|text_end|>\n" + audio_data + "\n");
+ }
+
+ // getAudioCompletionGuideTokens(text: string): Int32Array
+ Napi::Value
+ LlamaContext::GetAudioCompletionGuideTokens(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ if (info.Length() < 1 || !info[0].IsString()) {
+ Napi::TypeError::New(env,
+ "String expected for audio completion guide tokens")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+ auto text = info[0].ToString().Utf8Value();
+ const tts_type type = getTTSType(env);
+ auto clean_text = process_text(text, type);
+ const std::string &delimiter =
+ (type == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
+ const llama_vocab *vocab = llama_model_get_vocab(_sess->model());
+
+ std::vector<int32_t> result;
+ size_t start = 0;
+ size_t end = clean_text.find(delimiter);
+
+ // first token is always a newline, as it was not previously added
+ result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
+
+ while (end != std::string::npos) {
+ std::string current_word = clean_text.substr(start, end - start);
+ auto tmp = common_tokenize(vocab, current_word, false, true);
+ result.push_back(tmp[0]);
+ start = end + delimiter.length();
+ end = clean_text.find(delimiter, start);
+ }
+
+ // Add the last part
+ std::string current_word = clean_text.substr(start);
+ auto tmp = common_tokenize(vocab, current_word, false, true);
+ if (tmp.size() > 0) {
+ result.push_back(tmp[0]);
+ }
+ auto tokens = Napi::Int32Array::New(env, result.size());
+ memcpy(tokens.Data(), result.data(), result.size() * sizeof(int32_t));
+ return tokens;
+ }
+
+ // decodeAudioTokens(tokens: number[]|Int32Array): Float32Array
+ Napi::Value LlamaContext::DecodeAudioTokens(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ if (info.Length() < 1) {
+ Napi::TypeError::New(env, "Tokens parameter is required")
+ .ThrowAsJavaScriptException();
+ }
+ std::vector<int32_t> tokens;
+ if (info[0].IsTypedArray()) {
+ auto js_tokens = info[0].As<Napi::Int32Array>();
+ tokens.resize(js_tokens.ElementLength());
+ memcpy(tokens.data(), js_tokens.Data(),
+ js_tokens.ElementLength() * sizeof(int32_t));
+ } else if (info[0].IsArray()) {
+ auto js_tokens = info[0].As<Napi::Array>();
+ for (size_t i = 0; i < js_tokens.Length(); i++) {
+ tokens.push_back(js_tokens.Get(i).ToNumber().Int32Value());
+ }
+ } else {
+ Napi::TypeError::New(env, "Tokens must be an number array or a Int32Array")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+ tts_type type = getTTSType(env);
+ if (type == UNKNOWN) {
+ Napi::Error::New(env, "Unsupported audio tokens")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+ if (type == OUTETTS_V0_3 || type == OUTETTS_V0_2) {
+ tokens.erase(
+ std::remove_if(tokens.begin(), tokens.end(),
+ [](llama_token t) { return t < 151672 || t > 155772; }),
+ tokens.end());
+ for (auto &token : tokens) {
+ token -= 151672;
+ }
+ }
+ auto worker = new DecodeAudioTokenWorker(
+ info, _vocoder.model.get(), _vocoder.context.get(),
+ _sess->params().cpuparams.n_threads, tokens);
+ worker->Queue();
+ return worker->Promise();
+ }
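Taken together, the additions in this release implement an OuteTTS-style text-to-speech pipeline: a vocoder model is loaded alongside the main model, a speaker-aware prompt and guide tokens are built from the input text, completion generates audio tokens, and decodeAudioTokens converts them to audio samples through DecodeAudioTokenWorker. A hedged end-to-end TypeScript sketch; only the method names, the guide_tokens option, and the Int32Array/Promise types come from this diff, while the completion() call shape, the result payload, and the file path are illustrative assumptions:

// Illustrative flow for the 0.6.0 TTS additions; not the package's documented API.
async function speak(context: any, speakerJson: string | null, text: string): Promise<Float32Array> {
  if (!context.isVocoderEnabled()) {
    context.initVocoder('/path/to/vocoder.gguf'); // assumed vocoder model path
  }
  const prompt = context.getFormattedAudioCompletion(speakerJson, text);
  const guideTokens = context.getAudioCompletionGuideTokens(text);
  // Assumed completion() signature; "guide_tokens" is the option added in this release.
  const result = await context.completion({ prompt, guide_tokens: Array.from(guideTokens) });
  // result.audio_tokens is an assumed field on the completion result.
  const samples: Float32Array = await context.decodeAudioTokens(result.audio_tokens);
  context.releaseVocoder();
  return samples;
}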