@fugood/llama.node 1.1.10 → 1.2.0-rc.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. package/CMakeLists.txt +5 -8
  2. package/lib/binding.ts +20 -2
  3. package/lib/index.js +2 -2
  4. package/lib/index.ts +2 -2
  5. package/package.json +20 -16
  6. package/src/DecodeAudioTokenWorker.cpp +23 -26
  7. package/src/DecodeAudioTokenWorker.h +6 -8
  8. package/src/DetokenizeWorker.cpp +5 -8
  9. package/src/DetokenizeWorker.h +6 -5
  10. package/src/DisposeWorker.cpp +23 -3
  11. package/src/DisposeWorker.h +4 -2
  12. package/src/EmbeddingWorker.cpp +9 -35
  13. package/src/EmbeddingWorker.h +3 -2
  14. package/src/LlamaCompletionWorker.cpp +217 -315
  15. package/src/LlamaCompletionWorker.h +6 -12
  16. package/src/LlamaContext.cpp +174 -388
  17. package/src/LlamaContext.h +8 -13
  18. package/src/LoadSessionWorker.cpp +22 -19
  19. package/src/LoadSessionWorker.h +3 -2
  20. package/src/RerankWorker.h +3 -2
  21. package/src/SaveSessionWorker.cpp +22 -19
  22. package/src/SaveSessionWorker.h +3 -2
  23. package/src/TokenizeWorker.cpp +38 -35
  24. package/src/TokenizeWorker.h +12 -3
  25. package/src/common.hpp +0 -458
  26. package/src/llama.cpp/common/arg.cpp +67 -37
  27. package/src/llama.cpp/common/chat.cpp +263 -2
  28. package/src/llama.cpp/common/chat.h +4 -0
  29. package/src/llama.cpp/common/common.cpp +10 -3
  30. package/src/llama.cpp/common/common.h +5 -2
  31. package/src/llama.cpp/common/log.cpp +53 -2
  32. package/src/llama.cpp/common/log.h +10 -4
  33. package/src/llama.cpp/common/sampling.cpp +23 -2
  34. package/src/llama.cpp/common/sampling.h +3 -1
  35. package/src/llama.cpp/common/speculative.cpp +1 -1
  36. package/src/llama.cpp/ggml/CMakeLists.txt +4 -3
  37. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -0
  38. package/src/llama.cpp/ggml/include/ggml-cpu.h +0 -1
  39. package/src/llama.cpp/ggml/include/ggml.h +50 -1
  40. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +19 -16
  41. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
  42. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -7
  43. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +11 -37
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +3 -4
  45. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +43 -6
  46. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +4 -1
  47. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +18 -18
  48. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  49. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +234 -16
  50. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  51. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +80 -51
  52. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +161 -20
  53. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +399 -50
  54. package/src/llama.cpp/include/llama.h +32 -7
  55. package/src/llama.cpp/src/llama-adapter.cpp +101 -4
  56. package/src/llama.cpp/src/llama-adapter.h +6 -0
  57. package/src/llama.cpp/src/llama-arch.cpp +69 -2
  58. package/src/llama.cpp/src/llama-arch.h +6 -0
  59. package/src/llama.cpp/src/llama-context.cpp +92 -45
  60. package/src/llama.cpp/src/llama-context.h +1 -5
  61. package/src/llama.cpp/src/llama-graph.cpp +74 -19
  62. package/src/llama.cpp/src/llama-graph.h +10 -1
  63. package/src/llama.cpp/src/llama-hparams.cpp +37 -0
  64. package/src/llama.cpp/src/llama-hparams.h +9 -3
  65. package/src/llama.cpp/src/llama-impl.h +2 -0
  66. package/src/llama.cpp/src/llama-kv-cache.cpp +33 -120
  67. package/src/llama.cpp/src/llama-kv-cache.h +4 -13
  68. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  69. package/src/llama.cpp/src/llama-model.cpp +434 -21
  70. package/src/llama.cpp/src/llama-model.h +1 -1
  71. package/src/llama.cpp/src/llama-sampling.cpp +226 -126
  72. package/src/llama.cpp/src/llama-vocab.cpp +1 -1
  73. package/src/llama.cpp/src/llama.cpp +12 -0
  74. package/src/anyascii.c +0 -22223
  75. package/src/anyascii.h +0 -42
  76. package/src/tts_utils.cpp +0 -371
  77. package/src/tts_utils.h +0 -103
@@ -1,6 +1,4 @@
  #include "LlamaContext.h"
- #include "DecodeAudioTokenWorker.h"
- #include "DetokenizeWorker.h"
  #include "DisposeWorker.h"
  #include "EmbeddingWorker.h"
  #include "RerankWorker.h"
@@ -8,6 +6,8 @@
  #include "LoadSessionWorker.h"
  #include "SaveSessionWorker.h"
  #include "TokenizeWorker.h"
+ #include "DetokenizeWorker.h"
+ #include "DecodeAudioTokenWorker.h"
  #include "ggml.h"
  #include "gguf.h"
  #include "json-schema-to-grammar.h"
@@ -19,6 +19,8 @@
  #include <mutex>
  #include <queue>

+ using namespace rnllama;
+
  // Helper function for formatted strings (for console logs)
  template <typename... Args>
  static std::string format_string(const std::string &format, Args... args) {
@@ -175,21 +177,6 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
  exports.Set("LlamaContext", func);
  }

- const std::vector<ggml_type> kv_cache_types = {
- GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
- GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
- GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
- };
-
- static ggml_type kv_cache_type_from_str(const std::string &s) {
- for (const auto &type : kv_cache_types) {
- if (ggml_type_name(type) == s) {
- return type;
- }
- }
- throw std::runtime_error("Unsupported cache type: " + s);
- }
-
  static int32_t pooling_type_from_str(const std::string &s) {
  if (s == "none")
  return LLAMA_POOLING_TYPE_NONE;
@@ -242,11 +229,18 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  params.cpuparams.n_threads =
  get_option<int32_t>(options, "n_threads", cpu_get_num_math() / 2);
  params.n_gpu_layers = get_option<int32_t>(options, "n_gpu_layers", -1);
- params.flash_attn = get_option<bool>(options, "flash_attn", false);
+
+ auto flash_attn_type = get_option<std::string>(options, "flash_attn_type", "auto");
+ if (!flash_attn_type.empty()) {
+ params.flash_attn_type = (enum llama_flash_attn_type)flash_attn_type_from_str(flash_attn_type.c_str());
+ } else {
+ params.flash_attn_type = get_option<bool>(options, "flash_attn", false) ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
+ }
+
  params.cache_type_k = kv_cache_type_from_str(
- get_option<std::string>(options, "cache_type_k", "f16").c_str());
+ get_option<std::string>(options, "cache_type_k", "f16"));
  params.cache_type_v = kv_cache_type_from_str(
- get_option<std::string>(options, "cache_type_v", "f16").c_str());
+ get_option<std::string>(options, "cache_type_v", "f16"));
  params.ctx_shift = get_option<bool>(options, "ctx_shift", true);
  params.kv_unified = get_option<bool>(options, "kv_unified", false);
  params.swa_full = get_option<bool>(options, "swa_full", false);
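
Note: the hunk above replaces the boolean flash_attn option with a string-valued flash_attn_type option ("auto" by default, falling back to the legacy flash_attn boolean when the string is empty), and now passes cache_type_k / cache_type_v straight to the shared kv_cache_type_from_str helper instead of a local copy. A minimal TypeScript sketch of how these options might be passed from the JavaScript side; the loadModel entry point and the model option name are assumptions for illustration, while flash_attn_type, n_gpu_layers, cache_type_k and cache_type_v are the names read by the native code above:

    // Hypothetical usage sketch (names other than flash_attn_type,
    // n_gpu_layers, cache_type_k and cache_type_v are assumed, not confirmed).
    import { loadModel } from '@fugood/llama.node'

    const ctx = await loadModel({
      model: './model.gguf',      // assumed option name
      n_gpu_layers: -1,
      flash_attn_type: 'auto',    // replaces the old boolean flash_attn
      cache_type_k: 'q8_0',       // parsed by kv_cache_type_from_str
      cache_type_v: 'q8_0',
    })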
@@ -272,59 +266,55 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  llama_backend_init();
  llama_numa_init(params.numa);

- auto sess = std::make_shared<LlamaSession>(params);
-
- if (sess->model() == nullptr || sess->context() == nullptr) {
- Napi::TypeError::New(env, "Failed to load model")
- .ThrowAsJavaScriptException();
- }
-
- auto ctx = sess->context();
- auto model = sess->model();
-
  std::vector<common_adapter_lora_info> lora;
  auto lora_path = get_option<std::string>(options, "lora", "");
  auto lora_scaled = get_option<float>(options, "lora_scaled", 1.0f);
- if (lora_path != "") {
+ if (!lora_path.empty()) {
  common_adapter_lora_info la;
  la.path = lora_path;
  la.scale = lora_scaled;
- la.ptr = llama_adapter_lora_init(model, lora_path.c_str());
- if (la.ptr == nullptr) {
- Napi::TypeError::New(env, "Failed to load lora adapter")
- .ThrowAsJavaScriptException();
- }
  lora.push_back(la);
  }

  if (options.Has("lora_list") && options.Get("lora_list").IsArray()) {
  auto lora_list = options.Get("lora_list").As<Napi::Array>();
- if (lora_list != nullptr) {
- int lora_list_size = lora_list.Length();
- for (int i = 0; i < lora_list_size; i++) {
- auto lora_adapter = lora_list.Get(i).As<Napi::Object>();
- auto path = lora_adapter.Get("path").ToString();
- if (path != nullptr) {
- common_adapter_lora_info la;
- la.path = path;
- la.scale = lora_adapter.Get("scaled").ToNumber().FloatValue();
- la.ptr = llama_adapter_lora_init(model, path.Utf8Value().c_str());
- if (la.ptr == nullptr) {
- Napi::TypeError::New(env, "Failed to load lora adapter")
- .ThrowAsJavaScriptException();
- }
- lora.push_back(la);
- }
+ for (uint32_t i = 0; i < lora_list.Length(); i++) {
+ auto lora_adapter = lora_list.Get(i).As<Napi::Object>();
+ if (lora_adapter.Has("path")) {
+ common_adapter_lora_info la;
+ la.path = lora_adapter.Get("path").ToString();
+ la.scale = lora_adapter.Get("scaled").ToNumber().FloatValue();
+ lora.push_back(la);
  }
  }
  }
- common_set_adapter_lora(ctx, lora);
- _lora = lora;
+ // Use rn-llama context instead of direct session
+ _rn_ctx = new llama_rn_context();
+ if (!_rn_ctx->loadModel(params)) {
+ delete _rn_ctx;
+ _rn_ctx = nullptr;
+ Napi::TypeError::New(env, "Failed to load model").ThrowAsJavaScriptException();
+ }

- _sess = sess;
+ // Handle LoRA adapters through rn-llama
+ if (!lora.empty()) {
+ _rn_ctx->applyLoraAdapters(lora);
+ }
+
  _info = common_params_get_system_info(params);
+ }

- _templates = common_chat_templates_init(model, params.chat_template);
+ LlamaContext::~LlamaContext() {
+ // The DisposeWorker is responsible for cleanup of _rn_ctx
+ // If _rn_ctx is still not null here, it means disposal was not properly initiated
+ if (_rn_ctx) {
+ try {
+ delete _rn_ctx;
+ _rn_ctx = nullptr;
+ } catch (...) {
+ // Ignore errors during cleanup to avoid crashes in destructor
+ }
+ }
  }

  // getSystemInfo(): string
@@ -332,14 +322,6 @@ Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) {
  return Napi::String::New(info.Env(), _info);
  }

- bool validateModelChatTemplate(const struct llama_model *model,
- const bool use_jinja, const char *name) {
- const char *tmpl = llama_model_chat_template(model, name);
- if (tmpl == nullptr) {
- return false;
- }
- return common_chat_verify_template(tmpl, use_jinja);
- }

  // Store log messages for processing
  struct LogMessage {
@@ -434,8 +416,12 @@ void LlamaContext::ToggleNativeLog(const Napi::CallbackInfo &info) {

  // getModelInfo(): object
  Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {
+ if (!_rn_ctx || !_rn_ctx->model) {
+ Napi::TypeError::New(info.Env(), "Model not loaded")
+ .ThrowAsJavaScriptException();
+ }
  char desc[1024];
- auto model = _sess->model();
+ auto model = _rn_ctx->model;
  llama_model_desc(model, desc, sizeof(desc));

  int count = llama_model_meta_count(model);
@@ -455,51 +441,38 @@ Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {
  details.Set("size", llama_model_size(model));

  Napi::Object chatTemplates = Napi::Object::New(info.Env());
- chatTemplates.Set("llamaChat", validateModelChatTemplate(model, false, nullptr));
+ chatTemplates.Set("llamaChat", _rn_ctx->validateModelChatTemplate(false, nullptr));
  Napi::Object minja = Napi::Object::New(info.Env());
- minja.Set("default", validateModelChatTemplate(model, true, nullptr));
+ minja.Set("default", _rn_ctx->validateModelChatTemplate(true, nullptr));
  Napi::Object defaultCaps = Napi::Object::New(info.Env());
+ auto default_tmpl = _rn_ctx->templates.get()->template_default.get();
+ auto default_tmpl_caps = default_tmpl->original_caps();
  defaultCaps.Set(
  "tools",
- _templates.get()->template_default->original_caps().supports_tools);
+ default_tmpl_caps.supports_tools);
  defaultCaps.Set(
  "toolCalls",
- _templates.get()->template_default->original_caps().supports_tool_calls);
- defaultCaps.Set("toolResponses", _templates.get()
- ->template_default->original_caps()
- .supports_tool_responses);
+ default_tmpl_caps.supports_tool_calls);
+ defaultCaps.Set("toolResponses", default_tmpl_caps.supports_tool_responses);
  defaultCaps.Set(
  "systemRole",
- _templates.get()->template_default->original_caps().supports_system_role);
- defaultCaps.Set("parallelToolCalls", _templates.get()
- ->template_default->original_caps()
- .supports_parallel_tool_calls);
- defaultCaps.Set("toolCallId", _templates.get()
- ->template_default->original_caps()
- .supports_tool_call_id);
+ default_tmpl_caps.supports_system_role);
+ defaultCaps.Set("parallelToolCalls", default_tmpl_caps.supports_parallel_tool_calls);
+ defaultCaps.Set("toolCallId", default_tmpl_caps.supports_tool_call_id);
  minja.Set("defaultCaps", defaultCaps);
- minja.Set("toolUse", validateModelChatTemplate(model, true, "tool_use"));
- if (_templates.get()->template_tool_use) {
+ minja.Set("toolUse", _rn_ctx->validateModelChatTemplate(true, "tool_use"));
+ if (_rn_ctx->validateModelChatTemplate(true, "tool_use")) {
  Napi::Object toolUseCaps = Napi::Object::New(info.Env());
+ auto tool_use_tmpl = _rn_ctx->templates.get()->template_tool_use.get();
+ auto tool_use_tmpl_caps = tool_use_tmpl->original_caps();
  toolUseCaps.Set(
  "tools",
- _templates.get()->template_tool_use->original_caps().supports_tools);
- toolUseCaps.Set("toolCalls", _templates.get()
- ->template_tool_use->original_caps()
- .supports_tool_calls);
- toolUseCaps.Set("toolResponses", _templates.get()
- ->template_tool_use->original_caps()
- .supports_tool_responses);
- toolUseCaps.Set("systemRole", _templates.get()
- ->template_tool_use->original_caps()
- .supports_system_role);
- toolUseCaps.Set("parallelToolCalls",
- _templates.get()
- ->template_tool_use->original_caps()
- .supports_parallel_tool_calls);
- toolUseCaps.Set("toolCallId", _templates.get()
- ->template_tool_use->original_caps()
- .supports_tool_call_id);
+ tool_use_tmpl_caps.supports_tools);
+ toolUseCaps.Set("toolCalls", tool_use_tmpl_caps.supports_tool_calls);
+ toolUseCaps.Set("toolResponses", tool_use_tmpl_caps.supports_tool_responses);
+ toolUseCaps.Set("systemRole", tool_use_tmpl_caps.supports_system_role);
+ toolUseCaps.Set("parallelToolCalls", tool_use_tmpl_caps.supports_parallel_tool_calls);
+ toolUseCaps.Set("toolCallId", tool_use_tmpl_caps.supports_tool_call_id);
  minja.Set("toolUseCaps", toolUseCaps);
  }
  chatTemplates.Set("minja", minja);
@@ -509,76 +482,11 @@ Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {

  // Deprecated: use chatTemplates.llamaChat instead
  details.Set("isChatTemplateSupported",
- validateModelChatTemplate(_sess->model(), false, nullptr));
+ _rn_ctx->validateModelChatTemplate(false, nullptr));
  return details;
  }

- common_chat_params getFormattedChatWithJinja(
- const std::shared_ptr<LlamaSession> &sess,
- const common_chat_templates_ptr &templates, const std::string &messages,
- const std::string &chat_template, const std::string &json_schema,
- const std::string &tools, const bool &parallel_tool_calls,
- const std::string &tool_choice,
- const bool &enable_thinking,
- const bool &add_generation_prompt,
- const std::string &now_str,
- const std::map<std::string, std::string> &chat_template_kwargs
- ) {
- common_chat_templates_inputs inputs;
- inputs.messages = common_chat_msgs_parse_oaicompat(json::parse(messages));
- auto useTools = !tools.empty();
- if (useTools) {
- inputs.tools = common_chat_tools_parse_oaicompat(json::parse(tools));
- }
- inputs.parallel_tool_calls = parallel_tool_calls;
- if (!tool_choice.empty()) {
- inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
- }
- if (!json_schema.empty()) {
- inputs.json_schema = json::parse(json_schema);
- }
- inputs.enable_thinking = enable_thinking;
- inputs.add_generation_prompt = add_generation_prompt;
-
- // Handle now parameter - parse timestamp or use current time
- if (!now_str.empty()) {
- try {
- // Try to parse as timestamp (seconds since epoch)
- auto timestamp = std::stoll(now_str);
- inputs.now = std::chrono::system_clock::from_time_t(timestamp);
- } catch (...) {
- // If parsing fails, use current time
- inputs.now = std::chrono::system_clock::now();
- }
- }
-
- inputs.chat_template_kwargs = chat_template_kwargs;
-
- // If chat_template is provided, create new one and use it (probably slow)
- if (!chat_template.empty()) {
- auto tmps = common_chat_templates_init(sess->model(), chat_template);
- return common_chat_templates_apply(tmps.get(), inputs);
- } else {
- return common_chat_templates_apply(templates.get(), inputs);
- }
- }

- std::string getFormattedChat(const struct llama_model *model,
- const common_chat_templates_ptr &templates,
- const std::string &messages,
- const std::string &chat_template) {
- common_chat_templates_inputs inputs;
- inputs.messages = common_chat_msgs_parse_oaicompat(json::parse(messages));
- inputs.use_jinja = false;
-
- // If chat_template is provided, create new one and use it (probably slow)
- if (!chat_template.empty()) {
- auto tmps = common_chat_templates_init(model, chat_template);
- return common_chat_templates_apply(tmps.get(), inputs).prompt;
- } else {
- return common_chat_templates_apply(templates.get(), inputs).prompt;
- }
- }

  // getFormattedChat(
  // messages: [{ role: string, content: string }],
@@ -588,13 +496,16 @@ std::string getFormattedChat(const struct llama_model *model,
  // ): object | string
  Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
+ if (!_rn_ctx) {
+ Napi::TypeError::New(env, "Context is disposed").ThrowAsJavaScriptException();
+ }
  if (info.Length() < 1 || !info[0].IsArray()) {
  Napi::TypeError::New(env, "Array expected").ThrowAsJavaScriptException();
  }
  auto messages = json_stringify(info[0].As<Napi::Array>());
  auto chat_template = info[1].IsString() ? info[1].ToString().Utf8Value() : "";

- auto has_params = info.Length() >= 2;
+ auto has_params = info.Length() >= 3;
  auto params =
  has_params ? info[2].As<Napi::Object>() : Napi::Object::New(env);

@@ -643,8 +554,8 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {

  common_chat_params chatParams;
  try {
- chatParams = getFormattedChatWithJinja(
- _sess, _templates, messages, chat_template, json_schema_str, tools_str,
+ chatParams = _rn_ctx->getFormattedChatWithJinja(
+ messages, chat_template, json_schema_str, tools_str,
  parallel_tool_calls, tool_choice, enable_thinking,
  add_generation_prompt, now_str, chat_template_kwargs);
  } catch (const nlohmann::json_abi_v3_12_0::detail::parse_error& e) {
@@ -699,7 +610,7 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  return result;
  } else {
  auto formatted =
- getFormattedChat(_sess->model(), _templates, messages, chat_template);
+ _rn_ctx->getFormattedChat(messages, chat_template);
  return Napi::String::New(env, formatted);
  }
  }
@@ -714,7 +625,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  if (info.Length() >= 2 && !info[1].IsFunction()) {
  Napi::TypeError::New(env, "Function expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
@@ -746,7 +657,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  }

  // Check if multimodal is enabled when media_paths are provided
- if (!media_paths.empty() && !(_has_multimodal && _mtmd_ctx != nullptr)) {
+ if (!media_paths.empty() && !(_rn_ctx->has_multimodal && _rn_ctx->mtmd_wrapper != nullptr)) {
  Napi::Error::New(env, "Multimodal support must be enabled via "
  "initMultimodal to use media_paths")
  .ThrowAsJavaScriptException();
@@ -757,7 +668,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  bool thinking_forced_open = get_option<bool>(options, "thinking_forced_open", false);
  std::string reasoning_format = get_option<std::string>(options, "reasoning_format", "none");

- common_params params = _sess->params();
+ common_params params = _rn_ctx->params;
  auto grammar_from_params = get_option<std::string>(options, "grammar", "");
  auto has_grammar_set = !grammar_from_params.empty();
  if (has_grammar_set) {
@@ -790,7 +701,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  for (size_t i = 0; i < preserved_tokens.Length(); i++) {
  auto token = preserved_tokens.Get(i).ToString().Utf8Value();
  auto ids =
- common_tokenize(_sess->context(), token, /* add_special= */ false,
+ common_tokenize(_rn_ctx->ctx, token, /* add_special= */ false,
  /* parse_special= */ true);
  if (ids.size() == 1) {
  params.sampling.preserved_tokens.insert(ids[0]);
@@ -810,7 +721,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {

  if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
  auto ids =
- common_tokenize(_sess->context(), word, /* add_special= */ false,
+ common_tokenize(_rn_ctx->ctx, word, /* add_special= */ false,
  /* parse_special= */ true);
  if (ids.size() == 1) {
  auto token = ids[0];
@@ -881,8 +792,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  common_chat_params chatParams;

  try {
- chatParams = getFormattedChatWithJinja(
- _sess, _templates, json_stringify(messages), chat_template,
+ chatParams = _rn_ctx->getFormattedChatWithJinja(
+ json_stringify(messages), chat_template,
  json_schema_str, tools_str, parallel_tool_calls, tool_choice, enable_thinking,
  add_generation_prompt, now_str, chat_template_kwargs);
  } catch (const std::exception &e) {
@@ -897,7 +808,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {

  for (const auto &token : chatParams.preserved_tokens) {
  auto ids =
- common_tokenize(_sess->context(), token, /* add_special= */ false,
+ common_tokenize(_rn_ctx->ctx, token, /* add_special= */ false,
  /* parse_special= */ true);
  if (ids.size() == 1) {
  params.sampling.preserved_tokens.insert(ids[0]);
@@ -918,8 +829,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  stop_words.push_back(stop);
  }
  } else {
- auto formatted = getFormattedChat(
- _sess->model(), _templates, json_stringify(messages), chat_template);
+ auto formatted = _rn_ctx->getFormattedChat(
+ json_stringify(messages), chat_template);
  params.prompt = formatted;
  }
  } else {
@@ -973,6 +884,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  params.n_keep = get_option<int32_t>(options, "n_keep", 0);
  params.sampling.seed =
  get_option<int32_t>(options, "seed", LLAMA_DEFAULT_SEED);
+ params.sampling.n_probs = get_option<int32_t>(options, "n_probs", 0);

  // guide_tokens
  std::vector<llama_token> guide_tokens;
@@ -1007,9 +919,9 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  }

  auto *worker =
- new LlamaCompletionWorker(info, _sess, callback, params, stop_words,
+ new LlamaCompletionWorker(info, _rn_ctx, callback, params, stop_words,
  chat_format, thinking_forced_open, reasoning_format, media_paths, guide_tokens,
- _has_vocoder, _tts_type, prefill_text);
+ _rn_ctx->has_vocoder, _rn_ctx->tts_wrapper ? _rn_ctx->tts_wrapper->type : rnllama::UNKNOWN, prefill_text);
  worker->Queue();
  _wip = worker;
  worker->OnComplete([this]() { _wip = nullptr; });
@@ -1023,25 +935,28 @@ void LlamaContext::StopCompletion(const Napi::CallbackInfo &info) {
  }
  }

- // tokenize(text: string): Promise<TokenizeResult>
+ // tokenize(text: string, ): Promise<TokenizeResult>
  Napi::Value LlamaContext::Tokenize(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  if (info.Length() < 1 || !info[0].IsString()) {
  Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
  auto text = info[0].ToString().Utf8Value();
  std::vector<std::string> media_paths;
+
  if (info.Length() >= 2 && info[1].IsArray()) {
+ // Direct array format: tokenize(text, [media_paths])
  auto media_paths_array = info[1].As<Napi::Array>();
  for (size_t i = 0; i < media_paths_array.Length(); i++) {
  media_paths.push_back(media_paths_array.Get(i).ToString().Utf8Value());
  }
  }
- auto *worker = new TokenizeWorker(info, _sess, text, media_paths);
+
+ auto *worker = new TokenizeWorker(info, _rn_ctx, text, media_paths);
  worker->Queue();
  return worker->Promise();
  }
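
The Tokenize binding above now accepts an optional second argument with media paths ("Direct array format: tokenize(text, [media_paths])") and forwards them to the rn-llama TokenizeWorker. A hedged TypeScript sketch of the corresponding JavaScript call; the method name mirrors the comment in the hunk, while the result shape is assumed rather than taken from lib/binding.ts, which is not shown here:

    // Sketch only; passing media paths requires initMultimodal() to have
    // been called first (see the multimodal hunks later in this diff).
    const result = await ctx.tokenize('Describe this image', ['./photo.jpg'])
    console.log(result)   // TokenizeResult; exact shape not shown in this diff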
@@ -1052,7 +967,7 @@ Napi::Value LlamaContext::Detokenize(const Napi::CallbackInfo &info) {
  if (info.Length() < 1 || !info[0].IsArray()) {
  Napi::TypeError::New(env, "Array expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
@@ -1061,7 +976,8 @@ Napi::Value LlamaContext::Detokenize(const Napi::CallbackInfo &info) {
  for (size_t i = 0; i < tokens.Length(); i++) {
  token_ids.push_back(tokens.Get(i).ToNumber().Int32Value());
  }
- auto *worker = new DetokenizeWorker(info, _sess, token_ids);
+
+ auto *worker = new DetokenizeWorker(info, _rn_ctx, token_ids);
  worker->Queue();
  return worker->Promise();
  }
@@ -1072,7 +988,7 @@ Napi::Value LlamaContext::Embedding(const Napi::CallbackInfo &info) {
  if (info.Length() < 1 || !info[0].IsString()) {
  Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
@@ -1085,7 +1001,7 @@ Napi::Value LlamaContext::Embedding(const Napi::CallbackInfo &info) {
  embdParams.embedding = true;
  embdParams.embd_normalize = get_option<int32_t>(options, "embd_normalize", 2);
  auto text = info[0].ToString().Utf8Value();
- auto *worker = new EmbeddingWorker(info, _sess, text, embdParams);
+ auto *worker = new EmbeddingWorker(info, _rn_ctx, text, embdParams);
  worker->Queue();
  return worker->Promise();
  }
@@ -1096,7 +1012,7 @@ Napi::Value LlamaContext::Rerank(const Napi::CallbackInfo &info) {
  if (info.Length() < 2 || !info[0].IsString() || !info[1].IsArray()) {
  Napi::TypeError::New(env, "Query string and documents array expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
@@ -1119,7 +1035,7 @@ Napi::Value LlamaContext::Rerank(const Napi::CallbackInfo &info) {
  rerankParams.embedding = true;
  rerankParams.embd_normalize = get_option<int32_t>(options, "normalize", -1);

- auto *worker = new RerankWorker(info, _sess, query, documents, rerankParams);
+ auto *worker = new RerankWorker(info, _rn_ctx, query, documents, rerankParams);
  worker->Queue();
  return worker->Promise();
  }
@@ -1130,17 +1046,17 @@ Napi::Value LlamaContext::SaveSession(const Napi::CallbackInfo &info) {
  if (info.Length() < 1 || !info[0].IsString()) {
  Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
  #ifdef GGML_USE_VULKAN
- if (_sess->params().n_gpu_layers > 0) {
+ if (_rn_ctx->params.n_gpu_layers > 0) {
  Napi::TypeError::New(env, "Vulkan cannot save session")
  .ThrowAsJavaScriptException();
  }
  #endif
- auto *worker = new SaveSessionWorker(info, _sess);
+ auto *worker = new SaveSessionWorker(info, _rn_ctx);
  worker->Queue();
  return worker->Promise();
  }
@@ -1151,17 +1067,17 @@ Napi::Value LlamaContext::LoadSession(const Napi::CallbackInfo &info) {
  if (info.Length() < 1 || !info[0].IsString()) {
  Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
  #ifdef GGML_USE_VULKAN
- if (_sess->params().n_gpu_layers > 0) {
+ if (_rn_ctx->params.n_gpu_layers > 0) {
  Napi::TypeError::New(env, "Vulkan cannot load session")
  .ThrowAsJavaScriptException();
  }
  #endif
- auto *worker = new LoadSessionWorker(info, _sess);
+ auto *worker = new LoadSessionWorker(info, _rn_ctx);
  worker->Queue();
  return worker->Promise();
  }
@@ -1169,6 +1085,9 @@ Napi::Value LlamaContext::LoadSession(const Napi::CallbackInfo &info) {
  // applyLoraAdapters(lora_adapters: [{ path: string, scaled: number }]): void
  void LlamaContext::ApplyLoraAdapters(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
+ if (!_rn_ctx) {
+ Napi::TypeError::New(env, "Context is disposed").ThrowAsJavaScriptException();
+ }
  std::vector<common_adapter_lora_info> lora;
  auto lora_adapters = info[0].As<Napi::Array>();
  for (size_t i = 0; i < lora_adapters.Length(); i++) {
@@ -1178,21 +1097,16 @@ void LlamaContext::ApplyLoraAdapters(const Napi::CallbackInfo &info) {
  common_adapter_lora_info la;
  la.path = path;
  la.scale = scaled;
- la.ptr = llama_adapter_lora_init(_sess->model(), path.c_str());
- if (la.ptr == nullptr) {
- Napi::TypeError::New(env, "Failed to load lora adapter")
- .ThrowAsJavaScriptException();
- }
  lora.push_back(la);
  }
- common_set_adapter_lora(_sess->context(), lora);
- _lora = lora;
+ _rn_ctx->applyLoraAdapters(lora);
  }

  // removeLoraAdapters(): void
  void LlamaContext::RemoveLoraAdapters(const Napi::CallbackInfo &info) {
- _lora.clear();
- common_set_adapter_lora(_sess->context(), _lora);
+ if (_rn_ctx) {
+ _rn_ctx->removeLoraAdapters();
+ }
  }

  // getLoadedLoraAdapters(): Promise<{ count, lora_adapters: [{ path: string,
@@ -1200,11 +1114,15 @@ void LlamaContext::RemoveLoraAdapters(const Napi::CallbackInfo &info) {
  Napi::Value
  LlamaContext::GetLoadedLoraAdapters(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
- Napi::Array lora_adapters = Napi::Array::New(env, _lora.size());
- for (size_t i = 0; i < _lora.size(); i++) {
+ if (!_rn_ctx) {
+ Napi::TypeError::New(env, "Context is disposed").ThrowAsJavaScriptException();
+ }
+ auto lora = _rn_ctx->getLoadedLoraAdapters();
+ Napi::Array lora_adapters = Napi::Array::New(env, lora.size());
+ for (size_t i = 0; i < lora.size(); i++) {
  Napi::Object lora_adapter = Napi::Object::New(env);
- lora_adapter.Set("path", _lora[i].path);
- lora_adapter.Set("scaled", _lora[i].scale);
+ lora_adapter.Set("path", lora[i].path);
+ lora_adapter.Set("scaled", lora[i].scale);
  lora_adapters.Set(i, lora_adapter);
  }
  return lora_adapters;
@@ -1217,18 +1135,13 @@ Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
  _wip->SetStop();
  }

- if (_sess == nullptr) {
+ if (_rn_ctx == nullptr) {
  auto promise = Napi::Promise::Deferred(env);
  promise.Resolve(env.Undefined());
  return promise.Promise();
  }

- // Clear the mtmd context reference in the session
- if (_mtmd_ctx != nullptr) {
- _sess->set_mtmd_ctx(nullptr);
- }
-
- auto *worker = new DisposeWorker(info, std::move(_sess));
+ auto *worker = new DisposeWorker(info, _rn_ctx, &_rn_ctx);
  worker->Queue();
  return worker->Promise();
  }
@@ -1242,13 +1155,6 @@ extern "C" void cleanup_logging() {
  }
  }

- LlamaContext::~LlamaContext() {
- if (_mtmd_ctx != nullptr) {
- mtmd_free(_mtmd_ctx);
- _mtmd_ctx = nullptr;
- _has_multimodal = false;
- }
- }

  // initMultimodal(options: { path: string, use_gpu?: boolean }): boolean
  Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {
@@ -1270,50 +1176,20 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {

  console_log(env, "Initializing multimodal with mmproj path: " + mmproj_path);

- auto model = _sess->model();
- auto ctx = _sess->context();
- if (model == nullptr) {
+ if (_rn_ctx->model == nullptr) {
  Napi::Error::New(env, "Model not loaded").ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }

- if (_mtmd_ctx != nullptr) {
- mtmd_free(_mtmd_ctx);
- _mtmd_ctx = nullptr;
- _has_multimodal = false;
- }
-
- // Initialize mtmd context
- mtmd_context_params mtmd_params = mtmd_context_params_default();
- mtmd_params.use_gpu = use_gpu;
- mtmd_params.print_timings = false;
- mtmd_params.n_threads = _sess->params().cpuparams.n_threads;
- mtmd_params.verbosity = (ggml_log_level)GGML_LOG_LEVEL_INFO;
-
- console_log(env, format_string(
- "Initializing mtmd context with threads=%d, use_gpu=%d",
- mtmd_params.n_threads, mtmd_params.use_gpu ? 1 : 0));
-
- _mtmd_ctx = mtmd_init_from_file(mmproj_path.c_str(), model, mtmd_params);
- if (_mtmd_ctx == nullptr) {
+ // Disable ctx_shift before initializing multimodal
+ _rn_ctx->params.ctx_shift = false;
+ bool result = _rn_ctx->initMultimodal(mmproj_path, use_gpu);
+ if (!result) {
  Napi::Error::New(env, "Failed to initialize multimodal context")
  .ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }

- _has_multimodal = true;
-
- // Share the mtmd context with the session
- _sess->set_mtmd_ctx(_mtmd_ctx);
-
- // Check if the model uses M-RoPE or non-causal attention
- bool uses_mrope = mtmd_decode_use_mrope(_mtmd_ctx);
- bool uses_non_causal = mtmd_decode_use_non_causal(_mtmd_ctx);
- console_log(
- env, format_string(
- "Model multimodal properties: uses_mrope=%d, uses_non_causal=%d",
- uses_mrope ? 1 : 0, uses_non_causal ? 1 : 0));
-
  console_log(env, "Multimodal context initialized successfully with mmproj: " +
  mmproj_path);
  return Napi::Boolean::New(env, true);
@@ -1321,8 +1197,7 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {

  // isMultimodalEnabled(): boolean
  Napi::Value LlamaContext::IsMultimodalEnabled(const Napi::CallbackInfo &info) {
- return Napi::Boolean::New(info.Env(),
- _has_multimodal && _mtmd_ctx != nullptr);
+ return Napi::Boolean::New(info.Env(), _rn_ctx->isMultimodalEnabled());
  }

  // getMultimodalSupport(): Promise<{ vision: boolean, audio: boolean }>
@@ -1330,10 +1205,10 @@ Napi::Value LlamaContext::GetMultimodalSupport(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  auto result = Napi::Object::New(env);

- if (_has_multimodal && _mtmd_ctx != nullptr) {
+ if (_rn_ctx->isMultimodalEnabled()) {
  result.Set("vision",
- Napi::Boolean::New(env, mtmd_support_vision(_mtmd_ctx)));
- result.Set("audio", Napi::Boolean::New(env, mtmd_support_audio(_mtmd_ctx)));
+ Napi::Boolean::New(env, _rn_ctx->isMultimodalSupportVision()));
+ result.Set("audio", Napi::Boolean::New(env, _rn_ctx->isMultimodalSupportAudio()));
  } else {
  result.Set("vision", Napi::Boolean::New(env, false));
  result.Set("audio", Napi::Boolean::New(env, false));
@@ -1344,42 +1219,14 @@ Napi::Value LlamaContext::GetMultimodalSupport(const Napi::CallbackInfo &info) {

  // releaseMultimodal(): void
  void LlamaContext::ReleaseMultimodal(const Napi::CallbackInfo &info) {
- if (_mtmd_ctx != nullptr) {
- // Clear the mtmd context reference in the session
- if (_sess != nullptr) {
- _sess->set_mtmd_ctx(nullptr);
- }
-
- // Free the mtmd context
- mtmd_free(_mtmd_ctx);
- _mtmd_ctx = nullptr;
- _has_multimodal = false;
- }
+ _rn_ctx->releaseMultimodal();
  }

- tts_type LlamaContext::getTTSType(Napi::Env env, nlohmann::json speaker) {
- if (speaker.is_object() && speaker.contains("version")) {
- std::string version = speaker["version"].get<std::string>();
- if (version == "0.2") {
- return OUTETTS_V0_2;
- } else if (version == "0.3") {
- return OUTETTS_V0_3;
- } else {
- Napi::Error::New(env, format_string("Unsupported speaker version '%s'\n",
- version.c_str()))
- .ThrowAsJavaScriptException();
- return UNKNOWN;
- }
+ rnllama::tts_type LlamaContext::getTTSType(Napi::Env env, nlohmann::json speaker) {
+ if (_rn_ctx->tts_wrapper) {
+ return _rn_ctx->tts_wrapper->getTTSType(_rn_ctx, speaker);
  }
- if (_tts_type != UNKNOWN) {
- return _tts_type;
- }
- const char *chat_template =
- llama_model_chat_template(_sess->model(), nullptr);
- if (chat_template && std::string(chat_template) == "outetts-0.3") {
- return OUTETTS_V0_3;
- }
- return OUTETTS_V0_2;
+ return rnllama::UNKNOWN;
  }

  // initVocoder(params?: object): boolean
@@ -1391,49 +1238,34 @@ Napi::Value LlamaContext::InitVocoder(const Napi::CallbackInfo &info) {
  }
  auto options = info[0].As<Napi::Object>();
  auto vocoder_path = options.Get("path").ToString().Utf8Value();
- auto n_batch = get_option<int32_t>(options, "n_batch", _sess->params().n_batch);
+ auto n_batch = get_option<int32_t>(options, "n_batch", _rn_ctx->params.n_batch);
  if (vocoder_path.empty()) {
  Napi::TypeError::New(env, "vocoder path is required")
  .ThrowAsJavaScriptException();
  }
- if (_has_vocoder) {
+ if (_rn_ctx->has_vocoder) {
  Napi::Error::New(env, "Vocoder already initialized")
  .ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }
- _tts_type = getTTSType(env);
- _vocoder.params = _sess->params();
- _vocoder.params.warmup = false;
- _vocoder.params.model.path = vocoder_path;
- _vocoder.params.embedding = true;
- _vocoder.params.ctx_shift = false;
- _vocoder.params.n_batch = n_batch;
- _vocoder.params.n_ubatch = _vocoder.params.n_batch;
- common_init_result result = common_init_from_params(_vocoder.params);
- if (result.model == nullptr || result.context == nullptr) {
+ bool result = _rn_ctx->initVocoder(vocoder_path, n_batch);
+ if (!result) {
  Napi::Error::New(env, "Failed to initialize vocoder")
  .ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }
- _vocoder.model = std::move(result.model);
- _vocoder.context = std::move(result.context);
- _has_vocoder = true;
  return Napi::Boolean::New(env, true);
  }

  // releaseVocoder(): void
  void LlamaContext::ReleaseVocoder(const Napi::CallbackInfo &info) {
- if (_has_vocoder) {
- _vocoder.model.reset();
- _vocoder.context.reset();
- _has_vocoder = false;
- }
+ _rn_ctx->releaseVocoder();
  }

  // isVocoderEnabled(): boolean
  Napi::Value LlamaContext::IsVocoderEnabled(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
- return Napi::Boolean::New(env, _has_vocoder);
+ return Napi::Boolean::New(env, _rn_ctx->isVocoderEnabled());
  }

  // getFormattedAudioCompletion(speaker: string|null, text: string): object
@@ -1446,31 +1278,18 @@ LlamaContext::GetFormattedAudioCompletion(const Napi::CallbackInfo &info) {
  }
  auto text = info[1].ToString().Utf8Value();
  auto speaker_json = info[0].IsString() ? info[0].ToString().Utf8Value() : "";
- nlohmann::json speaker =
- speaker_json.empty() ? nullptr : nlohmann::json::parse(speaker_json);
- const tts_type type = getTTSType(env, speaker);
- std::string audio_text = DEFAULT_AUDIO_TEXT;
- std::string audio_data = DEFAULT_AUDIO_DATA;
- if (type == OUTETTS_V0_3) {
- audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"),
- "<|space|>");
- audio_data =
- std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
- audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"),
- "<|space|>");
- }
- if (!speaker_json.empty()) {
- audio_text = audio_text_from_speaker(speaker, type);
- audio_data = audio_data_from_speaker(speaker, type);
- }
- std::string prompt = "<|im_start|>\n" + audio_text +
- process_text(text, type) +
- "<|text_end|>\n" + audio_data + "\n";
+
+ if (!_rn_ctx->tts_wrapper) {
+ Napi::Error::New(env, "Vocoder not initialized")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+
+ auto result_data = _rn_ctx->tts_wrapper->getFormattedAudioCompletion(_rn_ctx, speaker_json, text);
  Napi::Object result = Napi::Object::New(env);
- result.Set("prompt", prompt);
- const char *grammar = get_tts_grammar(type);
- if (grammar != nullptr) {
- result.Set("grammar", grammar);
+ result.Set("prompt", Napi::String::New(env, result_data.prompt));
+ if (result_data.grammar) {
+ result.Set("grammar", Napi::String::New(env, result_data.grammar));
  }
  return result;
  }
@@ -1486,37 +1305,14 @@ LlamaContext::GetAudioCompletionGuideTokens(const Napi::CallbackInfo &info) {
  return env.Undefined();
  }
  auto text = info[0].ToString().Utf8Value();
- const tts_type type = getTTSType(env);
- auto clean_text = process_text(text, type);
- const std::string &delimiter =
- (type == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
- const llama_vocab *vocab = llama_model_get_vocab(_sess->model());
-
- std::vector<int32_t> result;
- size_t start = 0;
- size_t end = clean_text.find(delimiter);
-
- // first token is always a newline, as it was not previously added
- result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
-
- while (end != std::string::npos) {
- std::string current_word = clean_text.substr(start, end - start);
- auto tmp = common_tokenize(vocab, current_word, false, true);
- result.push_back(tmp[0]);
- start = end + delimiter.length();
- end = clean_text.find(delimiter, start);
- }
-
- // Add the last part
- std::string current_word = clean_text.substr(start);
- auto tmp = common_tokenize(vocab, current_word, false, true);
- if (tmp.size() > 0) {
- result.push_back(tmp[0]);
+
+ if (!_rn_ctx->tts_wrapper) {
+ Napi::Error::New(env, "Vocoder not initialized")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
  }
-
- // Add Audio End, forcing stop generation
- result.push_back(common_tokenize(vocab, "<|audio_end|>", false, true)[0]);
-
+
+ auto result = _rn_ctx->tts_wrapper->getAudioCompletionGuideTokens(_rn_ctx, text);
  auto tokens = Napi::Int32Array::New(env, result.size());
  memcpy(tokens.Data(), result.data(), result.size() * sizeof(int32_t));
  return tokens;
@@ -1528,6 +1324,12 @@ Napi::Value LlamaContext::DecodeAudioTokens(const Napi::CallbackInfo &info) {
  if (info.Length() < 1) {
  Napi::TypeError::New(env, "Tokens parameter is required")
  .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+ if (info[0].IsNull() || info[0].IsUndefined()) {
+ Napi::TypeError::New(env, "Tokens parameter cannot be null or undefined")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
  }
  std::vector<int32_t> tokens;
  if (info[0].IsTypedArray()) {
@@ -1545,24 +1347,8 @@ Napi::Value LlamaContext::DecodeAudioTokens(const Napi::CallbackInfo &info) {
  .ThrowAsJavaScriptException();
  return env.Undefined();
  }
- tts_type type = getTTSType(env);
- if (type == UNKNOWN) {
- Napi::Error::New(env, "Unsupported audio tokens")
- .ThrowAsJavaScriptException();
- return env.Undefined();
- }
- if (type == OUTETTS_V0_1 || type == OUTETTS_V0_2 || type == OUTETTS_V0_3) {
- tokens.erase(
- std::remove_if(tokens.begin(), tokens.end(),
- [](llama_token t) { return t < 151672 || t > 155772; }),
- tokens.end());
- for (auto &token : tokens) {
- token -= 151672;
- }
- }
- auto worker = new DecodeAudioTokenWorker(
- info, _vocoder.model.get(), _vocoder.context.get(),
- _sess->params().cpuparams.n_threads, tokens);
+
+ auto *worker = new DecodeAudioTokenWorker(info, _rn_ctx, tokens);
  worker->Queue();
  return worker->Promise();
  }
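
With this refactor, ~LlamaContext() is only a last-resort fallback: the DisposeWorker queued by Release() is the intended teardown path and receives a pointer to _rn_ctx so it can also null it out. A short TypeScript usage sketch of the JavaScript side; the release() name mirrors LlamaContext::Release above, but the surrounding API shape is an assumption:

    // Sketch: explicitly release the native context instead of relying on GC;
    // the destructor only cleans up when disposal was never initiated.
    try {
      // ... completion / embedding / tokenize calls ...
    } finally {
      await ctx.release()   // queues the DisposeWorker that deletes _rn_ctx
    }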