@fugood/llama.node 1.1.11 → 1.2.0-rc.0

This diff shows the published contents of the two package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (69)
  1. package/CMakeLists.txt +5 -8
  2. package/lib/binding.ts +18 -1
  3. package/lib/index.js +2 -2
  4. package/lib/index.ts +2 -2
  5. package/package.json +20 -16
  6. package/src/DecodeAudioTokenWorker.cpp +23 -26
  7. package/src/DecodeAudioTokenWorker.h +6 -8
  8. package/src/DetokenizeWorker.cpp +5 -8
  9. package/src/DetokenizeWorker.h +6 -5
  10. package/src/DisposeWorker.cpp +23 -3
  11. package/src/DisposeWorker.h +4 -2
  12. package/src/EmbeddingWorker.cpp +9 -35
  13. package/src/EmbeddingWorker.h +3 -2
  14. package/src/LlamaCompletionWorker.cpp +217 -315
  15. package/src/LlamaCompletionWorker.h +6 -12
  16. package/src/LlamaContext.cpp +166 -396
  17. package/src/LlamaContext.h +8 -13
  18. package/src/LoadSessionWorker.cpp +22 -19
  19. package/src/LoadSessionWorker.h +3 -2
  20. package/src/RerankWorker.h +3 -2
  21. package/src/SaveSessionWorker.cpp +22 -19
  22. package/src/SaveSessionWorker.h +3 -2
  23. package/src/TokenizeWorker.cpp +38 -35
  24. package/src/TokenizeWorker.h +12 -3
  25. package/src/common.hpp +0 -458
  26. package/src/llama.cpp/common/arg.cpp +50 -30
  27. package/src/llama.cpp/common/chat.cpp +111 -1
  28. package/src/llama.cpp/common/chat.h +3 -0
  29. package/src/llama.cpp/common/common.h +1 -1
  30. package/src/llama.cpp/common/log.cpp +53 -2
  31. package/src/llama.cpp/common/log.h +10 -4
  32. package/src/llama.cpp/common/sampling.cpp +23 -2
  33. package/src/llama.cpp/common/sampling.h +3 -1
  34. package/src/llama.cpp/common/speculative.cpp +1 -1
  35. package/src/llama.cpp/ggml/CMakeLists.txt +3 -2
  36. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -0
  37. package/src/llama.cpp/ggml/include/ggml-cpu.h +0 -1
  38. package/src/llama.cpp/ggml/include/ggml.h +50 -1
  39. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +14 -13
  40. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
  41. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +0 -6
  42. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +11 -37
  43. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +3 -4
  44. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +4 -9
  45. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +218 -4
  46. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  47. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +41 -37
  48. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +150 -28
  49. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +320 -73
  50. package/src/llama.cpp/include/llama.h +5 -6
  51. package/src/llama.cpp/src/llama-adapter.cpp +33 -0
  52. package/src/llama.cpp/src/llama-adapter.h +3 -0
  53. package/src/llama.cpp/src/llama-arch.cpp +27 -4
  54. package/src/llama.cpp/src/llama-arch.h +2 -0
  55. package/src/llama.cpp/src/llama-context.cpp +62 -56
  56. package/src/llama.cpp/src/llama-context.h +1 -1
  57. package/src/llama.cpp/src/llama-graph.cpp +54 -9
  58. package/src/llama.cpp/src/llama-graph.h +8 -0
  59. package/src/llama.cpp/src/llama-hparams.cpp +37 -0
  60. package/src/llama.cpp/src/llama-hparams.h +9 -3
  61. package/src/llama.cpp/src/llama-kv-cache.cpp +1 -23
  62. package/src/llama.cpp/src/llama-kv-cache.h +1 -0
  63. package/src/llama.cpp/src/llama-model.cpp +159 -1
  64. package/src/llama.cpp/src/llama-model.h +0 -1
  65. package/src/llama.cpp/src/llama-sampling.cpp +226 -126
  66. package/src/anyascii.c +0 -22223
  67. package/src/anyascii.h +0 -42
  68. package/src/tts_utils.cpp +0 -371
  69. package/src/tts_utils.h +0 -103
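The hunks reproduced below are from package/src/LlamaContext.cpp, the central refactor of this release: the binding drops its internal LlamaSession and the helpers that lived in src/common.hpp, and the bundled tts_utils and anyascii sources are deleted outright. Model loading, LoRA handling, multimodal (mtmd) support, and TTS are instead delegated to an rnllama::llama_rn_context instance held in the _rn_ctx member.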
@@ -1,6 +1,4 @@
  #include "LlamaContext.h"
- #include "DecodeAudioTokenWorker.h"
- #include "DetokenizeWorker.h"
  #include "DisposeWorker.h"
  #include "EmbeddingWorker.h"
  #include "RerankWorker.h"
@@ -8,6 +6,8 @@
  #include "LoadSessionWorker.h"
  #include "SaveSessionWorker.h"
  #include "TokenizeWorker.h"
+ #include "DetokenizeWorker.h"
+ #include "DecodeAudioTokenWorker.h"
  #include "ggml.h"
  #include "gguf.h"
  #include "json-schema-to-grammar.h"
@@ -19,6 +19,8 @@
  #include <mutex>
  #include <queue>

+ using namespace rnllama;
+
  // Helper function for formatted strings (for console logs)
  template <typename... Args>
  static std::string format_string(const std::string &format, Args... args) {
@@ -175,30 +177,6 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
  exports.Set("LlamaContext", func);
  }

- const std::vector<ggml_type> kv_cache_types = {
- GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
- GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
- GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
- };
-
- static ggml_type kv_cache_type_from_str(const std::string &s) {
- for (const auto &type : kv_cache_types) {
- if (ggml_type_name(type) == s) {
- return type;
- }
- }
- throw std::runtime_error("Unsupported cache type: " + s);
- }
-
- static enum llama_flash_attn_type flash_attn_type_from_str(const std::string &s) {
- if (s == "on")
- return LLAMA_FLASH_ATTN_TYPE_ENABLED;
- if (s == "off")
- return LLAMA_FLASH_ATTN_TYPE_DISABLED;
- return LLAMA_FLASH_ATTN_TYPE_AUTO;
- }
-
-
  static int32_t pooling_type_from_str(const std::string &s) {
  if (s == "none")
  return LLAMA_POOLING_TYPE_NONE;
@@ -260,9 +238,9 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  }

  params.cache_type_k = kv_cache_type_from_str(
- get_option<std::string>(options, "cache_type_k", "f16").c_str());
+ get_option<std::string>(options, "cache_type_k", "f16"));
  params.cache_type_v = kv_cache_type_from_str(
- get_option<std::string>(options, "cache_type_v", "f16").c_str());
+ get_option<std::string>(options, "cache_type_v", "f16"));
  params.ctx_shift = get_option<bool>(options, "ctx_shift", true);
  params.kv_unified = get_option<bool>(options, "kv_unified", false);
  params.swa_full = get_option<bool>(options, "swa_full", false);
@@ -288,59 +266,55 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  llama_backend_init();
  llama_numa_init(params.numa);

- auto sess = std::make_shared<LlamaSession>(params);
-
- if (sess->model() == nullptr || sess->context() == nullptr) {
- Napi::TypeError::New(env, "Failed to load model")
- .ThrowAsJavaScriptException();
- }
-
- auto ctx = sess->context();
- auto model = sess->model();
-
  std::vector<common_adapter_lora_info> lora;
  auto lora_path = get_option<std::string>(options, "lora", "");
  auto lora_scaled = get_option<float>(options, "lora_scaled", 1.0f);
- if (lora_path != "") {
+ if (!lora_path.empty()) {
  common_adapter_lora_info la;
  la.path = lora_path;
  la.scale = lora_scaled;
- la.ptr = llama_adapter_lora_init(model, lora_path.c_str());
- if (la.ptr == nullptr) {
- Napi::TypeError::New(env, "Failed to load lora adapter")
- .ThrowAsJavaScriptException();
- }
  lora.push_back(la);
  }

  if (options.Has("lora_list") && options.Get("lora_list").IsArray()) {
  auto lora_list = options.Get("lora_list").As<Napi::Array>();
- if (lora_list != nullptr) {
- int lora_list_size = lora_list.Length();
- for (int i = 0; i < lora_list_size; i++) {
- auto lora_adapter = lora_list.Get(i).As<Napi::Object>();
- auto path = lora_adapter.Get("path").ToString();
- if (path != nullptr) {
- common_adapter_lora_info la;
- la.path = path;
- la.scale = lora_adapter.Get("scaled").ToNumber().FloatValue();
- la.ptr = llama_adapter_lora_init(model, path.Utf8Value().c_str());
- if (la.ptr == nullptr) {
- Napi::TypeError::New(env, "Failed to load lora adapter")
- .ThrowAsJavaScriptException();
- }
- lora.push_back(la);
- }
+ for (uint32_t i = 0; i < lora_list.Length(); i++) {
+ auto lora_adapter = lora_list.Get(i).As<Napi::Object>();
+ if (lora_adapter.Has("path")) {
+ common_adapter_lora_info la;
+ la.path = lora_adapter.Get("path").ToString();
+ la.scale = lora_adapter.Get("scaled").ToNumber().FloatValue();
+ lora.push_back(la);
  }
  }
  }
- common_set_adapter_lora(ctx, lora);
- _lora = lora;
+ // Use rn-llama context instead of direct session
+ _rn_ctx = new llama_rn_context();
+ if (!_rn_ctx->loadModel(params)) {
+ delete _rn_ctx;
+ _rn_ctx = nullptr;
+ Napi::TypeError::New(env, "Failed to load model").ThrowAsJavaScriptException();
+ }

- _sess = sess;
+ // Handle LoRA adapters through rn-llama
+ if (!lora.empty()) {
+ _rn_ctx->applyLoraAdapters(lora);
+ }
+
  _info = common_params_get_system_info(params);
+ }

- _templates = common_chat_templates_init(model, params.chat_template);
+ LlamaContext::~LlamaContext() {
+ // The DisposeWorker is responsible for cleanup of _rn_ctx
+ // If _rn_ctx is still not null here, it means disposal was not properly initiated
+ if (_rn_ctx) {
+ try {
+ delete _rn_ctx;
+ _rn_ctx = nullptr;
+ } catch (...) {
+ // Ignore errors during cleanup to avoid crashes in destructor
+ }
+ }
  }

  // getSystemInfo(): string
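Read together, the constructor and destructor hunks above define the new ownership model: _rn_ctx is a plain owning pointer created with new and populated through loadModel(), normally torn down asynchronously by the DisposeWorker (see release() further down), with the destructor only as a last-resort fallback. A condensed sketch of that pattern, assuming the package's N-API and rn-llama headers and using only names that appear in this diff:

// Condensed sketch (not the actual sources) of the ownership pattern the new code follows.
struct LlamaContextOwnership {
  rnllama::llama_rn_context *_rn_ctx = nullptr;   // owning raw pointer, as in the real class

  void load(common_params &params, Napi::Env env) {
    _rn_ctx = new rnllama::llama_rn_context();
    if (!_rn_ctx->loadModel(params)) {            // rn-llama now owns model/context setup
      delete _rn_ctx;
      _rn_ctx = nullptr;
      Napi::TypeError::New(env, "Failed to load model").ThrowAsJavaScriptException();
    }
  }

  bool guard(Napi::Env env) {                     // every binding method now starts with this check
    if (!_rn_ctx) {
      Napi::TypeError::New(env, "Context is disposed").ThrowAsJavaScriptException();
      return false;
    }
    return true;
  }

  ~LlamaContextOwnership() { delete _rn_ctx; }    // fallback when release() was never called
};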
@@ -348,14 +322,6 @@ Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) {
  return Napi::String::New(info.Env(), _info);
  }

- bool validateModelChatTemplate(const struct llama_model *model,
- const bool use_jinja, const char *name) {
- const char *tmpl = llama_model_chat_template(model, name);
- if (tmpl == nullptr) {
- return false;
- }
- return common_chat_verify_template(tmpl, use_jinja);
- }

  // Store log messages for processing
  struct LogMessage {
@@ -450,8 +416,12 @@ void LlamaContext::ToggleNativeLog(const Napi::CallbackInfo &info) {

  // getModelInfo(): object
  Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {
+ if (!_rn_ctx || !_rn_ctx->model) {
+ Napi::TypeError::New(info.Env(), "Model not loaded")
+ .ThrowAsJavaScriptException();
+ }
  char desc[1024];
- auto model = _sess->model();
+ auto model = _rn_ctx->model;
  llama_model_desc(model, desc, sizeof(desc));

  int count = llama_model_meta_count(model);
@@ -471,51 +441,38 @@ Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {
  details.Set("size", llama_model_size(model));

  Napi::Object chatTemplates = Napi::Object::New(info.Env());
- chatTemplates.Set("llamaChat", validateModelChatTemplate(model, false, nullptr));
+ chatTemplates.Set("llamaChat", _rn_ctx->validateModelChatTemplate(false, nullptr));
  Napi::Object minja = Napi::Object::New(info.Env());
- minja.Set("default", validateModelChatTemplate(model, true, nullptr));
+ minja.Set("default", _rn_ctx->validateModelChatTemplate(true, nullptr));
  Napi::Object defaultCaps = Napi::Object::New(info.Env());
+ auto default_tmpl = _rn_ctx->templates.get()->template_default.get();
+ auto default_tmpl_caps = default_tmpl->original_caps();
  defaultCaps.Set(
  "tools",
- _templates.get()->template_default->original_caps().supports_tools);
+ default_tmpl_caps.supports_tools);
  defaultCaps.Set(
  "toolCalls",
- _templates.get()->template_default->original_caps().supports_tool_calls);
- defaultCaps.Set("toolResponses", _templates.get()
- ->template_default->original_caps()
- .supports_tool_responses);
+ default_tmpl_caps.supports_tool_calls);
+ defaultCaps.Set("toolResponses", default_tmpl_caps.supports_tool_responses);
  defaultCaps.Set(
  "systemRole",
- _templates.get()->template_default->original_caps().supports_system_role);
- defaultCaps.Set("parallelToolCalls", _templates.get()
- ->template_default->original_caps()
- .supports_parallel_tool_calls);
- defaultCaps.Set("toolCallId", _templates.get()
- ->template_default->original_caps()
- .supports_tool_call_id);
+ default_tmpl_caps.supports_system_role);
+ defaultCaps.Set("parallelToolCalls", default_tmpl_caps.supports_parallel_tool_calls);
+ defaultCaps.Set("toolCallId", default_tmpl_caps.supports_tool_call_id);
  minja.Set("defaultCaps", defaultCaps);
- minja.Set("toolUse", validateModelChatTemplate(model, true, "tool_use"));
- if (_templates.get()->template_tool_use) {
+ minja.Set("toolUse", _rn_ctx->validateModelChatTemplate(true, "tool_use"));
+ if (_rn_ctx->validateModelChatTemplate(true, "tool_use")) {
  Napi::Object toolUseCaps = Napi::Object::New(info.Env());
+ auto tool_use_tmpl = _rn_ctx->templates.get()->template_tool_use.get();
+ auto tool_use_tmpl_caps = tool_use_tmpl->original_caps();
  toolUseCaps.Set(
  "tools",
- _templates.get()->template_tool_use->original_caps().supports_tools);
- toolUseCaps.Set("toolCalls", _templates.get()
- ->template_tool_use->original_caps()
- .supports_tool_calls);
- toolUseCaps.Set("toolResponses", _templates.get()
- ->template_tool_use->original_caps()
- .supports_tool_responses);
- toolUseCaps.Set("systemRole", _templates.get()
- ->template_tool_use->original_caps()
- .supports_system_role);
- toolUseCaps.Set("parallelToolCalls",
- _templates.get()
- ->template_tool_use->original_caps()
- .supports_parallel_tool_calls);
- toolUseCaps.Set("toolCallId", _templates.get()
- ->template_tool_use->original_caps()
- .supports_tool_call_id);
+ tool_use_tmpl_caps.supports_tools);
+ toolUseCaps.Set("toolCalls", tool_use_tmpl_caps.supports_tool_calls);
+ toolUseCaps.Set("toolResponses", tool_use_tmpl_caps.supports_tool_responses);
+ toolUseCaps.Set("systemRole", tool_use_tmpl_caps.supports_system_role);
+ toolUseCaps.Set("parallelToolCalls", tool_use_tmpl_caps.supports_parallel_tool_calls);
+ toolUseCaps.Set("toolCallId", tool_use_tmpl_caps.supports_tool_call_id);
  minja.Set("toolUseCaps", toolUseCaps);
  }
  chatTemplates.Set("minja", minja);
@@ -525,76 +482,11 @@ Napi::Value LlamaContext::GetModelInfo(const Napi::CallbackInfo &info) {

  // Deprecated: use chatTemplates.llamaChat instead
  details.Set("isChatTemplateSupported",
- validateModelChatTemplate(_sess->model(), false, nullptr));
+ _rn_ctx->validateModelChatTemplate(false, nullptr));
  return details;
  }

- common_chat_params getFormattedChatWithJinja(
- const std::shared_ptr<LlamaSession> &sess,
- const common_chat_templates_ptr &templates, const std::string &messages,
- const std::string &chat_template, const std::string &json_schema,
- const std::string &tools, const bool &parallel_tool_calls,
- const std::string &tool_choice,
- const bool &enable_thinking,
- const bool &add_generation_prompt,
- const std::string &now_str,
- const std::map<std::string, std::string> &chat_template_kwargs
- ) {
- common_chat_templates_inputs inputs;
- inputs.messages = common_chat_msgs_parse_oaicompat(json::parse(messages));
- auto useTools = !tools.empty();
- if (useTools) {
- inputs.tools = common_chat_tools_parse_oaicompat(json::parse(tools));
- }
- inputs.parallel_tool_calls = parallel_tool_calls;
- if (!tool_choice.empty()) {
- inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
- }
- if (!json_schema.empty()) {
- inputs.json_schema = json::parse(json_schema);
- }
- inputs.enable_thinking = enable_thinking;
- inputs.add_generation_prompt = add_generation_prompt;
-
- // Handle now parameter - parse timestamp or use current time
- if (!now_str.empty()) {
- try {
- // Try to parse as timestamp (seconds since epoch)
- auto timestamp = std::stoll(now_str);
- inputs.now = std::chrono::system_clock::from_time_t(timestamp);
- } catch (...) {
- // If parsing fails, use current time
- inputs.now = std::chrono::system_clock::now();
- }
- }
-
- inputs.chat_template_kwargs = chat_template_kwargs;
-
- // If chat_template is provided, create new one and use it (probably slow)
- if (!chat_template.empty()) {
- auto tmps = common_chat_templates_init(sess->model(), chat_template);
- return common_chat_templates_apply(tmps.get(), inputs);
- } else {
- return common_chat_templates_apply(templates.get(), inputs);
- }
- }

- std::string getFormattedChat(const struct llama_model *model,
- const common_chat_templates_ptr &templates,
- const std::string &messages,
- const std::string &chat_template) {
- common_chat_templates_inputs inputs;
- inputs.messages = common_chat_msgs_parse_oaicompat(json::parse(messages));
- inputs.use_jinja = false;
-
- // If chat_template is provided, create new one and use it (probably slow)
- if (!chat_template.empty()) {
- auto tmps = common_chat_templates_init(model, chat_template);
- return common_chat_templates_apply(tmps.get(), inputs).prompt;
- } else {
- return common_chat_templates_apply(templates.get(), inputs).prompt;
- }
- }

  // getFormattedChat(
  // messages: [{ role: string, content: string }],
@@ -604,13 +496,16 @@ std::string getFormattedChat(const struct llama_model *model,
  // ): object | string
  Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
+ if (!_rn_ctx) {
+ Napi::TypeError::New(env, "Context is disposed").ThrowAsJavaScriptException();
+ }
  if (info.Length() < 1 || !info[0].IsArray()) {
  Napi::TypeError::New(env, "Array expected").ThrowAsJavaScriptException();
  }
  auto messages = json_stringify(info[0].As<Napi::Array>());
  auto chat_template = info[1].IsString() ? info[1].ToString().Utf8Value() : "";

- auto has_params = info.Length() >= 2;
+ auto has_params = info.Length() >= 3;
  auto params =
  has_params ? info[2].As<Napi::Object>() : Napi::Object::New(env);
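Note the argument-count fix in the hunk above: getFormattedChat is documented as taking (messages, chat_template, params), so the optional params object is the third argument and its presence is now tested with info.Length() >= 3. The old >= 2 check treated a call that only supplied a chat template string as if a params object had also been passed.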
@@ -659,8 +554,8 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {

  common_chat_params chatParams;
  try {
- chatParams = getFormattedChatWithJinja(
- _sess, _templates, messages, chat_template, json_schema_str, tools_str,
+ chatParams = _rn_ctx->getFormattedChatWithJinja(
+ messages, chat_template, json_schema_str, tools_str,
  parallel_tool_calls, tool_choice, enable_thinking,
  add_generation_prompt, now_str, chat_template_kwargs);
  } catch (const nlohmann::json_abi_v3_12_0::detail::parse_error& e) {
@@ -715,7 +610,7 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
  return result;
  } else {
  auto formatted =
- getFormattedChat(_sess->model(), _templates, messages, chat_template);
+ _rn_ctx->getFormattedChat(messages, chat_template);
  return Napi::String::New(env, formatted);
  }
  }
@@ -730,7 +625,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  if (info.Length() >= 2 && !info[1].IsFunction()) {
  Napi::TypeError::New(env, "Function expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
@@ -762,7 +657,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  }

  // Check if multimodal is enabled when media_paths are provided
- if (!media_paths.empty() && !(_has_multimodal && _mtmd_ctx != nullptr)) {
+ if (!media_paths.empty() && !(_rn_ctx->has_multimodal && _rn_ctx->mtmd_wrapper != nullptr)) {
  Napi::Error::New(env, "Multimodal support must be enabled via "
  "initMultimodal to use media_paths")
  .ThrowAsJavaScriptException();
@@ -773,7 +668,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  bool thinking_forced_open = get_option<bool>(options, "thinking_forced_open", false);
  std::string reasoning_format = get_option<std::string>(options, "reasoning_format", "none");

- common_params params = _sess->params();
+ common_params params = _rn_ctx->params;
  auto grammar_from_params = get_option<std::string>(options, "grammar", "");
  auto has_grammar_set = !grammar_from_params.empty();
  if (has_grammar_set) {
@@ -806,7 +701,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  for (size_t i = 0; i < preserved_tokens.Length(); i++) {
  auto token = preserved_tokens.Get(i).ToString().Utf8Value();
  auto ids =
- common_tokenize(_sess->context(), token, /* add_special= */ false,
+ common_tokenize(_rn_ctx->ctx, token, /* add_special= */ false,
  /* parse_special= */ true);
  if (ids.size() == 1) {
  params.sampling.preserved_tokens.insert(ids[0]);
@@ -826,7 +721,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {

  if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
  auto ids =
- common_tokenize(_sess->context(), word, /* add_special= */ false,
+ common_tokenize(_rn_ctx->ctx, word, /* add_special= */ false,
  /* parse_special= */ true);
  if (ids.size() == 1) {
  auto token = ids[0];
@@ -897,8 +792,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  common_chat_params chatParams;

  try {
- chatParams = getFormattedChatWithJinja(
- _sess, _templates, json_stringify(messages), chat_template,
+ chatParams = _rn_ctx->getFormattedChatWithJinja(
+ json_stringify(messages), chat_template,
  json_schema_str, tools_str, parallel_tool_calls, tool_choice, enable_thinking,
  add_generation_prompt, now_str, chat_template_kwargs);
  } catch (const std::exception &e) {
@@ -913,7 +808,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {

  for (const auto &token : chatParams.preserved_tokens) {
  auto ids =
- common_tokenize(_sess->context(), token, /* add_special= */ false,
+ common_tokenize(_rn_ctx->ctx, token, /* add_special= */ false,
  /* parse_special= */ true);
  if (ids.size() == 1) {
  params.sampling.preserved_tokens.insert(ids[0]);
@@ -934,8 +829,8 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  stop_words.push_back(stop);
  }
  } else {
- auto formatted = getFormattedChat(
- _sess->model(), _templates, json_stringify(messages), chat_template);
+ auto formatted = _rn_ctx->getFormattedChat(
+ json_stringify(messages), chat_template);
  params.prompt = formatted;
  }
  } else {
@@ -989,6 +884,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  params.n_keep = get_option<int32_t>(options, "n_keep", 0);
  params.sampling.seed =
  get_option<int32_t>(options, "seed", LLAMA_DEFAULT_SEED);
+ params.sampling.n_probs = get_option<int32_t>(options, "n_probs", 0);

  // guide_tokens
  std::vector<llama_token> guide_tokens;
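The completion options also gain an n_probs field here, read with a default of 0 and forwarded to params.sampling.n_probs, the knob llama.cpp's common sampling parameters use for how many top-token probabilities to report per generated token.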
@@ -1023,9 +919,9 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  }

  auto *worker =
- new LlamaCompletionWorker(info, _sess, callback, params, stop_words,
+ new LlamaCompletionWorker(info, _rn_ctx, callback, params, stop_words,
  chat_format, thinking_forced_open, reasoning_format, media_paths, guide_tokens,
- _has_vocoder, _tts_type, prefill_text);
+ _rn_ctx->has_vocoder, _rn_ctx->tts_wrapper ? _rn_ctx->tts_wrapper->type : rnllama::UNKNOWN, prefill_text);
  worker->Queue();
  _wip = worker;
  worker->OnComplete([this]() { _wip = nullptr; });
@@ -1039,25 +935,28 @@ void LlamaContext::StopCompletion(const Napi::CallbackInfo &info) {
  }
  }

- // tokenize(text: string): Promise<TokenizeResult>
+ // tokenize(text: string, ): Promise<TokenizeResult>
  Napi::Value LlamaContext::Tokenize(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  if (info.Length() < 1 || !info[0].IsString()) {
  Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
  auto text = info[0].ToString().Utf8Value();
  std::vector<std::string> media_paths;
+
  if (info.Length() >= 2 && info[1].IsArray()) {
+ // Direct array format: tokenize(text, [media_paths])
  auto media_paths_array = info[1].As<Napi::Array>();
  for (size_t i = 0; i < media_paths_array.Length(); i++) {
  media_paths.push_back(media_paths_array.Get(i).ToString().Utf8Value());
  }
  }
- auto *worker = new TokenizeWorker(info, _sess, text, media_paths);
+
+ auto *worker = new TokenizeWorker(info, _rn_ctx, text, media_paths);
  worker->Queue();
  return worker->Promise();
  }
@@ -1068,7 +967,7 @@ Napi::Value LlamaContext::Detokenize(const Napi::CallbackInfo &info) {
  if (info.Length() < 1 || !info[0].IsArray()) {
  Napi::TypeError::New(env, "Array expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
@@ -1077,7 +976,8 @@ Napi::Value LlamaContext::Detokenize(const Napi::CallbackInfo &info) {
  for (size_t i = 0; i < tokens.Length(); i++) {
  token_ids.push_back(tokens.Get(i).ToNumber().Int32Value());
  }
- auto *worker = new DetokenizeWorker(info, _sess, token_ids);
+
+ auto *worker = new DetokenizeWorker(info, _rn_ctx, token_ids);
  worker->Queue();
  return worker->Promise();
  }
@@ -1088,7 +988,7 @@ Napi::Value LlamaContext::Embedding(const Napi::CallbackInfo &info) {
  if (info.Length() < 1 || !info[0].IsString()) {
  Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
@@ -1101,7 +1001,7 @@ Napi::Value LlamaContext::Embedding(const Napi::CallbackInfo &info) {
  embdParams.embedding = true;
  embdParams.embd_normalize = get_option<int32_t>(options, "embd_normalize", 2);
  auto text = info[0].ToString().Utf8Value();
- auto *worker = new EmbeddingWorker(info, _sess, text, embdParams);
+ auto *worker = new EmbeddingWorker(info, _rn_ctx, text, embdParams);
  worker->Queue();
  return worker->Promise();
  }
@@ -1112,7 +1012,7 @@ Napi::Value LlamaContext::Rerank(const Napi::CallbackInfo &info) {
  if (info.Length() < 2 || !info[0].IsString() || !info[1].IsArray()) {
  Napi::TypeError::New(env, "Query string and documents array expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
@@ -1135,7 +1035,7 @@ Napi::Value LlamaContext::Rerank(const Napi::CallbackInfo &info) {
  rerankParams.embedding = true;
  rerankParams.embd_normalize = get_option<int32_t>(options, "normalize", -1);

- auto *worker = new RerankWorker(info, _sess, query, documents, rerankParams);
+ auto *worker = new RerankWorker(info, _rn_ctx, query, documents, rerankParams);
  worker->Queue();
  return worker->Promise();
  }
@@ -1146,17 +1046,17 @@ Napi::Value LlamaContext::SaveSession(const Napi::CallbackInfo &info) {
  if (info.Length() < 1 || !info[0].IsString()) {
  Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
  #ifdef GGML_USE_VULKAN
- if (_sess->params().n_gpu_layers > 0) {
+ if (_rn_ctx->params.n_gpu_layers > 0) {
  Napi::TypeError::New(env, "Vulkan cannot save session")
  .ThrowAsJavaScriptException();
  }
  #endif
- auto *worker = new SaveSessionWorker(info, _sess);
+ auto *worker = new SaveSessionWorker(info, _rn_ctx);
  worker->Queue();
  return worker->Promise();
  }
@@ -1167,17 +1067,17 @@ Napi::Value LlamaContext::LoadSession(const Napi::CallbackInfo &info) {
  if (info.Length() < 1 || !info[0].IsString()) {
  Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
  }
- if (_sess == nullptr) {
+ if (!_rn_ctx) {
  Napi::TypeError::New(env, "Context is disposed")
  .ThrowAsJavaScriptException();
  }
  #ifdef GGML_USE_VULKAN
- if (_sess->params().n_gpu_layers > 0) {
+ if (_rn_ctx->params.n_gpu_layers > 0) {
  Napi::TypeError::New(env, "Vulkan cannot load session")
  .ThrowAsJavaScriptException();
  }
  #endif
- auto *worker = new LoadSessionWorker(info, _sess);
+ auto *worker = new LoadSessionWorker(info, _rn_ctx);
  worker->Queue();
  return worker->Promise();
  }
@@ -1185,6 +1085,9 @@ Napi::Value LlamaContext::LoadSession(const Napi::CallbackInfo &info) {
  // applyLoraAdapters(lora_adapters: [{ path: string, scaled: number }]): void
  void LlamaContext::ApplyLoraAdapters(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
+ if (!_rn_ctx) {
+ Napi::TypeError::New(env, "Context is disposed").ThrowAsJavaScriptException();
+ }
  std::vector<common_adapter_lora_info> lora;
  auto lora_adapters = info[0].As<Napi::Array>();
  for (size_t i = 0; i < lora_adapters.Length(); i++) {
@@ -1194,21 +1097,16 @@ void LlamaContext::ApplyLoraAdapters(const Napi::CallbackInfo &info) {
  common_adapter_lora_info la;
  la.path = path;
  la.scale = scaled;
- la.ptr = llama_adapter_lora_init(_sess->model(), path.c_str());
- if (la.ptr == nullptr) {
- Napi::TypeError::New(env, "Failed to load lora adapter")
- .ThrowAsJavaScriptException();
- }
  lora.push_back(la);
  }
- common_set_adapter_lora(_sess->context(), lora);
- _lora = lora;
+ _rn_ctx->applyLoraAdapters(lora);
  }

  // removeLoraAdapters(): void
  void LlamaContext::RemoveLoraAdapters(const Napi::CallbackInfo &info) {
- _lora.clear();
- common_set_adapter_lora(_sess->context(), _lora);
+ if (_rn_ctx) {
+ _rn_ctx->removeLoraAdapters();
+ }
  }

  // getLoadedLoraAdapters(): Promise<{ count, lora_adapters: [{ path: string,
@@ -1216,11 +1114,15 @@ void LlamaContext::RemoveLoraAdapters(const Napi::CallbackInfo &info) {
  Napi::Value
  LlamaContext::GetLoadedLoraAdapters(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
- Napi::Array lora_adapters = Napi::Array::New(env, _lora.size());
- for (size_t i = 0; i < _lora.size(); i++) {
+ if (!_rn_ctx) {
+ Napi::TypeError::New(env, "Context is disposed").ThrowAsJavaScriptException();
+ }
+ auto lora = _rn_ctx->getLoadedLoraAdapters();
+ Napi::Array lora_adapters = Napi::Array::New(env, lora.size());
+ for (size_t i = 0; i < lora.size(); i++) {
  Napi::Object lora_adapter = Napi::Object::New(env);
- lora_adapter.Set("path", _lora[i].path);
- lora_adapter.Set("scaled", _lora[i].scale);
+ lora_adapter.Set("path", lora[i].path);
+ lora_adapter.Set("scaled", lora[i].scale);
  lora_adapters.Set(i, lora_adapter);
  }
  return lora_adapters;
@@ -1233,18 +1135,13 @@ Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
  _wip->SetStop();
  }

- if (_sess == nullptr) {
+ if (_rn_ctx == nullptr) {
  auto promise = Napi::Promise::Deferred(env);
  promise.Resolve(env.Undefined());
  return promise.Promise();
  }

- // Clear the mtmd context reference in the session
- if (_mtmd_ctx != nullptr) {
- _sess->set_mtmd_ctx(nullptr);
- }
-
- auto *worker = new DisposeWorker(info, std::move(_sess));
+ auto *worker = new DisposeWorker(info, _rn_ctx, &_rn_ctx);
  worker->Queue();
  return worker->Promise();
  }
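release() now hands DisposeWorker the raw pointer together with the address of the member, (info, _rn_ctx, &_rn_ctx), instead of moving a shared_ptr. DisposeWorker.cpp and .h do change in this release but are not shown in this excerpt, so the following is only a hypothetical illustration of what the double pointer enables: freeing the rn-llama context off the JS thread while clearing the owner's member so later calls hit the "Context is disposed" guard.

// Hypothetical sketch only; the real DisposeWorker is not shown in this excerpt.
class DisposeWorkerSketch : public Napi::AsyncWorker {
public:
  DisposeWorkerSketch(Napi::Env env, rnllama::llama_rn_context *ctx,
                      rnllama::llama_rn_context **owner_slot)
      : Napi::AsyncWorker(env), ctx_(ctx) {
    if (owner_slot) *owner_slot = nullptr;   // detach from the owner on the JS thread
  }

protected:
  void Execute() override {                  // heavy teardown runs on the worker thread
    delete ctx_;
    ctx_ = nullptr;
  }
  void OnOK() override {}                    // the real worker resolves a Promise here

private:
  rnllama::llama_rn_context *ctx_;
};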
@@ -1258,13 +1155,6 @@ extern "C" void cleanup_logging() {
  }
  }

- LlamaContext::~LlamaContext() {
- if (_mtmd_ctx != nullptr) {
- mtmd_free(_mtmd_ctx);
- _mtmd_ctx = nullptr;
- _has_multimodal = false;
- }
- }

  // initMultimodal(options: { path: string, use_gpu?: boolean }): boolean
  Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {
@@ -1286,50 +1176,20 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {

  console_log(env, "Initializing multimodal with mmproj path: " + mmproj_path);

- auto model = _sess->model();
- auto ctx = _sess->context();
- if (model == nullptr) {
+ if (_rn_ctx->model == nullptr) {
  Napi::Error::New(env, "Model not loaded").ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }

- if (_mtmd_ctx != nullptr) {
- mtmd_free(_mtmd_ctx);
- _mtmd_ctx = nullptr;
- _has_multimodal = false;
- }
-
- // Initialize mtmd context
- mtmd_context_params mtmd_params = mtmd_context_params_default();
- mtmd_params.use_gpu = use_gpu;
- mtmd_params.print_timings = false;
- mtmd_params.n_threads = _sess->params().cpuparams.n_threads;
- mtmd_params.verbosity = (ggml_log_level)GGML_LOG_LEVEL_INFO;
-
- console_log(env, format_string(
- "Initializing mtmd context with threads=%d, use_gpu=%d",
- mtmd_params.n_threads, mtmd_params.use_gpu ? 1 : 0));
-
- _mtmd_ctx = mtmd_init_from_file(mmproj_path.c_str(), model, mtmd_params);
- if (_mtmd_ctx == nullptr) {
+ // Disable ctx_shift before initializing multimodal
+ _rn_ctx->params.ctx_shift = false;
+ bool result = _rn_ctx->initMultimodal(mmproj_path, use_gpu);
+ if (!result) {
  Napi::Error::New(env, "Failed to initialize multimodal context")
  .ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }

- _has_multimodal = true;
-
- // Share the mtmd context with the session
- _sess->set_mtmd_ctx(_mtmd_ctx);
-
- // Check if the model uses M-RoPE or non-causal attention
- bool uses_mrope = mtmd_decode_use_mrope(_mtmd_ctx);
- bool uses_non_causal = mtmd_decode_use_non_causal(_mtmd_ctx);
- console_log(
- env, format_string(
- "Model multimodal properties: uses_mrope=%d, uses_non_causal=%d",
- uses_mrope ? 1 : 0, uses_non_causal ? 1 : 0));
-
  console_log(env, "Multimodal context initialized successfully with mmproj: " +
  mmproj_path);
  return Napi::Boolean::New(env, true);
@@ -1337,8 +1197,7 @@ Napi::Value LlamaContext::InitMultimodal(const Napi::CallbackInfo &info) {

  // isMultimodalEnabled(): boolean
  Napi::Value LlamaContext::IsMultimodalEnabled(const Napi::CallbackInfo &info) {
- return Napi::Boolean::New(info.Env(),
- _has_multimodal && _mtmd_ctx != nullptr);
+ return Napi::Boolean::New(info.Env(), _rn_ctx->isMultimodalEnabled());
  }

  // getMultimodalSupport(): Promise<{ vision: boolean, audio: boolean }>
@@ -1346,10 +1205,10 @@ Napi::Value LlamaContext::GetMultimodalSupport(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  auto result = Napi::Object::New(env);

- if (_has_multimodal && _mtmd_ctx != nullptr) {
+ if (_rn_ctx->isMultimodalEnabled()) {
  result.Set("vision",
- Napi::Boolean::New(env, mtmd_support_vision(_mtmd_ctx)));
- result.Set("audio", Napi::Boolean::New(env, mtmd_support_audio(_mtmd_ctx)));
+ Napi::Boolean::New(env, _rn_ctx->isMultimodalSupportVision()));
+ result.Set("audio", Napi::Boolean::New(env, _rn_ctx->isMultimodalSupportAudio()));
  } else {
  result.Set("vision", Napi::Boolean::New(env, false));
  result.Set("audio", Napi::Boolean::New(env, false));
@@ -1360,42 +1219,14 @@ Napi::Value LlamaContext::GetMultimodalSupport(const Napi::CallbackInfo &info) {

  // releaseMultimodal(): void
  void LlamaContext::ReleaseMultimodal(const Napi::CallbackInfo &info) {
- if (_mtmd_ctx != nullptr) {
- // Clear the mtmd context reference in the session
- if (_sess != nullptr) {
- _sess->set_mtmd_ctx(nullptr);
- }
-
- // Free the mtmd context
- mtmd_free(_mtmd_ctx);
- _mtmd_ctx = nullptr;
- _has_multimodal = false;
- }
+ _rn_ctx->releaseMultimodal();
  }

- tts_type LlamaContext::getTTSType(Napi::Env env, nlohmann::json speaker) {
- if (speaker.is_object() && speaker.contains("version")) {
- std::string version = speaker["version"].get<std::string>();
- if (version == "0.2") {
- return OUTETTS_V0_2;
- } else if (version == "0.3") {
- return OUTETTS_V0_3;
- } else {
- Napi::Error::New(env, format_string("Unsupported speaker version '%s'\n",
- version.c_str()))
- .ThrowAsJavaScriptException();
- return UNKNOWN;
- }
+ rnllama::tts_type LlamaContext::getTTSType(Napi::Env env, nlohmann::json speaker) {
+ if (_rn_ctx->tts_wrapper) {
+ return _rn_ctx->tts_wrapper->getTTSType(_rn_ctx, speaker);
  }
- if (_tts_type != UNKNOWN) {
- return _tts_type;
- }
- const char *chat_template =
- llama_model_chat_template(_sess->model(), nullptr);
- if (chat_template && std::string(chat_template) == "outetts-0.3") {
- return OUTETTS_V0_3;
- }
- return OUTETTS_V0_2;
+ return rnllama::UNKNOWN;
  }

  // initVocoder(params?: object): boolean
@@ -1407,49 +1238,34 @@ Napi::Value LlamaContext::InitVocoder(const Napi::CallbackInfo &info) {
  }
  auto options = info[0].As<Napi::Object>();
  auto vocoder_path = options.Get("path").ToString().Utf8Value();
- auto n_batch = get_option<int32_t>(options, "n_batch", _sess->params().n_batch);
+ auto n_batch = get_option<int32_t>(options, "n_batch", _rn_ctx->params.n_batch);
  if (vocoder_path.empty()) {
  Napi::TypeError::New(env, "vocoder path is required")
  .ThrowAsJavaScriptException();
  }
- if (_has_vocoder) {
+ if (_rn_ctx->has_vocoder) {
  Napi::Error::New(env, "Vocoder already initialized")
  .ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }
- _tts_type = getTTSType(env);
- _vocoder.params = _sess->params();
- _vocoder.params.warmup = false;
- _vocoder.params.model.path = vocoder_path;
- _vocoder.params.embedding = true;
- _vocoder.params.ctx_shift = false;
- _vocoder.params.n_batch = n_batch;
- _vocoder.params.n_ubatch = _vocoder.params.n_batch;
- common_init_result result = common_init_from_params(_vocoder.params);
- if (result.model == nullptr || result.context == nullptr) {
+ bool result = _rn_ctx->initVocoder(vocoder_path, n_batch);
+ if (!result) {
  Napi::Error::New(env, "Failed to initialize vocoder")
  .ThrowAsJavaScriptException();
  return Napi::Boolean::New(env, false);
  }
- _vocoder.model = std::move(result.model);
- _vocoder.context = std::move(result.context);
- _has_vocoder = true;
  return Napi::Boolean::New(env, true);
  }

  // releaseVocoder(): void
  void LlamaContext::ReleaseVocoder(const Napi::CallbackInfo &info) {
- if (_has_vocoder) {
- _vocoder.model.reset();
- _vocoder.context.reset();
- _has_vocoder = false;
- }
+ _rn_ctx->releaseVocoder();
  }

  // isVocoderEnabled(): boolean
  Napi::Value LlamaContext::IsVocoderEnabled(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
- return Napi::Boolean::New(env, _has_vocoder);
+ return Napi::Boolean::New(env, _rn_ctx->isVocoderEnabled());
  }

  // getFormattedAudioCompletion(speaker: string|null, text: string): object
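From here to the end of the file the audio path follows the same delegation pattern: speaker parsing, OuteTTS prompt construction, guide-token generation, and audio-token filtering, previously implemented in the now-deleted tts_utils.cpp, all move behind _rn_ctx->tts_wrapper, and each entry point simply reports "Vocoder not initialized" when no tts_wrapper is present.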
@@ -1462,31 +1278,18 @@ LlamaContext::GetFormattedAudioCompletion(const Napi::CallbackInfo &info) {
  }
  auto text = info[1].ToString().Utf8Value();
  auto speaker_json = info[0].IsString() ? info[0].ToString().Utf8Value() : "";
- nlohmann::json speaker =
- speaker_json.empty() ? nullptr : nlohmann::json::parse(speaker_json);
- const tts_type type = getTTSType(env, speaker);
- std::string audio_text = DEFAULT_AUDIO_TEXT;
- std::string audio_data = DEFAULT_AUDIO_DATA;
- if (type == OUTETTS_V0_3) {
- audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"),
- "<|space|>");
- audio_data =
- std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
- audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"),
- "<|space|>");
- }
- if (!speaker_json.empty()) {
- audio_text = audio_text_from_speaker(speaker, type);
- audio_data = audio_data_from_speaker(speaker, type);
- }
- std::string prompt = "<|im_start|>\n" + audio_text +
- process_text(text, type) +
- "<|text_end|>\n" + audio_data + "\n";
+
+ if (!_rn_ctx->tts_wrapper) {
+ Napi::Error::New(env, "Vocoder not initialized")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+
+ auto result_data = _rn_ctx->tts_wrapper->getFormattedAudioCompletion(_rn_ctx, speaker_json, text);
  Napi::Object result = Napi::Object::New(env);
- result.Set("prompt", prompt);
- const char *grammar = get_tts_grammar(type);
- if (grammar != nullptr) {
- result.Set("grammar", grammar);
+ result.Set("prompt", Napi::String::New(env, result_data.prompt));
+ if (result_data.grammar) {
+ result.Set("grammar", Napi::String::New(env, result_data.grammar));
  }
  return result;
  }
@@ -1502,37 +1305,14 @@ LlamaContext::GetAudioCompletionGuideTokens(const Napi::CallbackInfo &info) {
  return env.Undefined();
  }
  auto text = info[0].ToString().Utf8Value();
- const tts_type type = getTTSType(env);
- auto clean_text = process_text(text, type);
- const std::string &delimiter =
- (type == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
- const llama_vocab *vocab = llama_model_get_vocab(_sess->model());
-
- std::vector<int32_t> result;
- size_t start = 0;
- size_t end = clean_text.find(delimiter);
-
- // first token is always a newline, as it was not previously added
- result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
-
- while (end != std::string::npos) {
- std::string current_word = clean_text.substr(start, end - start);
- auto tmp = common_tokenize(vocab, current_word, false, true);
- result.push_back(tmp[0]);
- start = end + delimiter.length();
- end = clean_text.find(delimiter, start);
- }
-
- // Add the last part
- std::string current_word = clean_text.substr(start);
- auto tmp = common_tokenize(vocab, current_word, false, true);
- if (tmp.size() > 0) {
- result.push_back(tmp[0]);
+
+ if (!_rn_ctx->tts_wrapper) {
+ Napi::Error::New(env, "Vocoder not initialized")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
  }
-
- // Add Audio End, forcing stop generation
- result.push_back(common_tokenize(vocab, "<|audio_end|>", false, true)[0]);
-
+
+ auto result = _rn_ctx->tts_wrapper->getAudioCompletionGuideTokens(_rn_ctx, text);
  auto tokens = Napi::Int32Array::New(env, result.size());
  memcpy(tokens.Data(), result.data(), result.size() * sizeof(int32_t));
  return tokens;
@@ -1544,6 +1324,12 @@ Napi::Value LlamaContext::DecodeAudioTokens(const Napi::CallbackInfo &info) {
  if (info.Length() < 1) {
  Napi::TypeError::New(env, "Tokens parameter is required")
  .ThrowAsJavaScriptException();
+ return env.Undefined();
+ }
+ if (info[0].IsNull() || info[0].IsUndefined()) {
+ Napi::TypeError::New(env, "Tokens parameter cannot be null or undefined")
+ .ThrowAsJavaScriptException();
+ return env.Undefined();
  }
  std::vector<int32_t> tokens;
  if (info[0].IsTypedArray()) {
@@ -1561,24 +1347,8 @@ Napi::Value LlamaContext::DecodeAudioTokens(const Napi::CallbackInfo &info) {
  .ThrowAsJavaScriptException();
  return env.Undefined();
  }
- tts_type type = getTTSType(env);
- if (type == UNKNOWN) {
- Napi::Error::New(env, "Unsupported audio tokens")
- .ThrowAsJavaScriptException();
- return env.Undefined();
- }
- if (type == OUTETTS_V0_1 || type == OUTETTS_V0_2 || type == OUTETTS_V0_3) {
- tokens.erase(
- std::remove_if(tokens.begin(), tokens.end(),
- [](llama_token t) { return t < 151672 || t > 155772; }),
- tokens.end());
- for (auto &token : tokens) {
- token -= 151672;
- }
- }
- auto worker = new DecodeAudioTokenWorker(
- info, _vocoder.model.get(), _vocoder.context.get(),
- _sess->params().cpuparams.n_threads, tokens);
+
+ auto *worker = new DecodeAudioTokenWorker(info, _rn_ctx, tokens);
  worker->Queue();
  return worker->Promise();
  }