@fugood/llama.node 1.1.11 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. package/CMakeLists.txt +5 -8
  2. package/lib/binding.ts +18 -1
  3. package/lib/index.js +2 -2
  4. package/lib/index.ts +2 -2
  5. package/package.json +20 -16
  6. package/src/DecodeAudioTokenWorker.cpp +23 -26
  7. package/src/DecodeAudioTokenWorker.h +6 -8
  8. package/src/DetokenizeWorker.cpp +5 -8
  9. package/src/DetokenizeWorker.h +6 -5
  10. package/src/DisposeWorker.cpp +23 -3
  11. package/src/DisposeWorker.h +4 -2
  12. package/src/EmbeddingWorker.cpp +9 -35
  13. package/src/EmbeddingWorker.h +3 -2
  14. package/src/LlamaCompletionWorker.cpp +217 -315
  15. package/src/LlamaCompletionWorker.h +6 -12
  16. package/src/LlamaContext.cpp +166 -396
  17. package/src/LlamaContext.h +8 -13
  18. package/src/LoadSessionWorker.cpp +22 -19
  19. package/src/LoadSessionWorker.h +3 -2
  20. package/src/RerankWorker.h +3 -2
  21. package/src/SaveSessionWorker.cpp +22 -19
  22. package/src/SaveSessionWorker.h +3 -2
  23. package/src/TokenizeWorker.cpp +38 -35
  24. package/src/TokenizeWorker.h +12 -3
  25. package/src/common.hpp +0 -458
  26. package/src/llama.cpp/common/arg.cpp +50 -30
  27. package/src/llama.cpp/common/chat.cpp +250 -1
  28. package/src/llama.cpp/common/chat.h +4 -0
  29. package/src/llama.cpp/common/common.h +1 -1
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +21 -1
  31. package/src/llama.cpp/common/log.cpp +53 -2
  32. package/src/llama.cpp/common/log.h +10 -4
  33. package/src/llama.cpp/common/sampling.cpp +23 -2
  34. package/src/llama.cpp/common/sampling.h +3 -1
  35. package/src/llama.cpp/common/speculative.cpp +1 -1
  36. package/src/llama.cpp/ggml/CMakeLists.txt +3 -2
  37. package/src/llama.cpp/ggml/include/ggml-backend.h +15 -0
  38. package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -1
  39. package/src/llama.cpp/ggml/include/ggml-metal.h +0 -6
  40. package/src/llama.cpp/ggml/include/ggml.h +56 -2
  41. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +21 -14
  42. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
  43. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +57 -59
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +6 -7
  45. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +25 -38
  46. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -4
  47. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +4 -12
  48. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +379 -4
  49. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  50. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +41 -37
  51. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +150 -28
  52. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +320 -73
  53. package/src/llama.cpp/include/llama.h +5 -6
  54. package/src/llama.cpp/src/llama-adapter.cpp +33 -0
  55. package/src/llama.cpp/src/llama-adapter.h +3 -0
  56. package/src/llama.cpp/src/llama-arch.cpp +28 -4
  57. package/src/llama.cpp/src/llama-arch.h +3 -0
  58. package/src/llama.cpp/src/llama-context.cpp +65 -57
  59. package/src/llama.cpp/src/llama-context.h +1 -1
  60. package/src/llama.cpp/src/llama-graph.cpp +57 -11
  61. package/src/llama.cpp/src/llama-graph.h +8 -0
  62. package/src/llama.cpp/src/llama-hparams.cpp +37 -0
  63. package/src/llama.cpp/src/llama-hparams.h +10 -3
  64. package/src/llama.cpp/src/llama-kv-cache.cpp +56 -38
  65. package/src/llama.cpp/src/llama-kv-cache.h +9 -0
  66. package/src/llama.cpp/src/llama-model.cpp +217 -97
  67. package/src/llama.cpp/src/llama-model.h +0 -1
  68. package/src/llama.cpp/src/llama-quant.cpp +3 -3
  69. package/src/llama.cpp/src/llama-sampling.cpp +226 -126
  70. package/src/llama.cpp/src/llama.cpp +53 -10
  71. package/src/anyascii.c +0 -22223
  72. package/src/anyascii.h +0 -42
  73. package/src/tts_utils.cpp +0 -371
  74. package/src/tts_utils.h +0 -103
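
The LlamaCompletionWorker rewrite below moves completion onto the new rn-llama completion context and changes what the streaming callback and the final result carry. As a rough orientation aid, here is a hedged TypeScript sketch of those shapes, inferred only from the Napi keys set in the diff that follows; the interface names are illustrative, not the package's published typings, and optional fields are only set when non-empty or when the relevant feature (n_probs, TTS) is enabled.

// Illustrative only: shapes inferred from the Napi keys set in LlamaCompletionWorker.cpp (1.2.0).
interface TokenProbEntry { tok_str: string; prob: number }
interface CompletionProbability { content: string; probs: TokenProbEntry[] }

interface ToolCall {
  type: 'function';
  function: { name: string; arguments: string };
  id?: string;
}

// Payload passed to the streaming token callback.
interface StreamTokenEvent {
  token: string;
  content?: string;                                   // parsed assistant content, omitted when empty
  reasoning_content?: string;                         // omitted when empty
  tool_calls?: ToolCall[];                            // omitted when empty
  accumulated_text: string;
  completion_probabilities?: CompletionProbability[]; // only when sampling.n_probs > 0
}

// Final completion result object.
interface CompletionResultJS {
  chat_format: number;                                // now always set in 1.2.0
  tokens_evaluated: number;
  tokens_predicted: number;
  truncated: boolean;
  context_full: boolean;
  interrupted: boolean;
  text: string;
  stopped_eos: boolean;
  stopped_words: boolean;
  stopping_word: string;
  audio_tokens?: number[];                            // TTS / vocoder mode only
  completion_probabilities?: CompletionProbability[]; // only when sampling.n_probs > 0
}
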
package/src/LlamaCompletionWorker.cpp

@@ -1,31 +1,36 @@
  #include "LlamaCompletionWorker.h"
  #include "LlamaContext.h"
+ #include "rn-llama/rn-completion.h"
  #include <limits>

- size_t findStoppingStrings(const std::string &text,
-                            const size_t last_token_size,
-                            const std::vector<std::string> &stop_words) {
-   size_t stop_pos = std::string::npos;
-
-   for (const std::string &word : stop_words) {
-     size_t pos;
-
-     const size_t tmp = word.size() + last_token_size;
-     const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
-
-     pos = text.find(word, from_pos);
-
-     if (pos != std::string::npos &&
-         (stop_pos == std::string::npos || pos < stop_pos)) {
-       stop_pos = pos;
+ // Helper function to convert token probabilities to JavaScript format
+ Napi::Array TokenProbsToArray(Napi::Env env, llama_context* ctx, const std::vector<rnllama::completion_token_output>& probs) {
+   Napi::Array result = Napi::Array::New(env);
+   for (size_t i = 0; i < probs.size(); i++) {
+     const auto &prob = probs[i];
+     Napi::Object token_obj = Napi::Object::New(env);
+
+     std::string token_str = common_token_to_piece(ctx, prob.tok);
+     token_obj.Set("content", Napi::String::New(env, token_str));
+
+     Napi::Array token_probs = Napi::Array::New(env);
+     for (size_t j = 0; j < prob.probs.size(); j++) {
+       const auto &p = prob.probs[j];
+       Napi::Object prob_obj = Napi::Object::New(env);
+       std::string tok_str = common_token_to_piece(ctx, p.tok);
+       prob_obj.Set("tok_str", Napi::String::New(env, tok_str));
+       prob_obj.Set("prob", Napi::Number::New(env, p.prob));
+       token_probs.Set(j, prob_obj);
      }
+     token_obj.Set("probs", token_probs);
+     result.Set(i, token_obj);
    }
-
-   return stop_pos;
+   return result;
  }

+
  LlamaCompletionWorker::LlamaCompletionWorker(
-     const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
+     const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx,
      Napi::Function callback,
      common_params params,
      std::vector<std::string> stop_words,
@@ -35,9 +40,9 @@ LlamaCompletionWorker::LlamaCompletionWorker(
      const std::vector<std::string> &media_paths,
      const std::vector<llama_token> &guide_tokens,
      bool has_vocoder,
-     tts_type tts_type_val,
+     rnllama::tts_type tts_type_val,
      const std::string &prefill_text)
-     : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess),
+     : AsyncWorker(info.Env()), Deferred(info.Env()), _rn_ctx(rn_ctx),
        _params(params), _stop_words(stop_words), _chat_format(chat_format),
        _thinking_forced_open(thinking_forced_open),
        _reasoning_format(reasoning_format),
@@ -57,298 +62,203 @@ LlamaCompletionWorker::~LlamaCompletionWorker() {
    }
  }

- LlamaCompletionWorker::PartialOutput LlamaCompletionWorker::getPartialOutput(const std::string &generated_text) {
-   PartialOutput result;
-
+
+ void LlamaCompletionWorker::Execute() {
    try {
-     common_chat_syntax chat_syntax;
-     chat_syntax.format = static_cast<common_chat_format>(_chat_format);
-     chat_syntax.thinking_forced_open = _thinking_forced_open;
+     // Check if vocab_only mode is enabled - if so, return empty result
+     if (_params.vocab_only) {
+       // Return empty completion result for vocab_only mode
+       _result.tokens_evaluated = 0;
+       _result.tokens_predicted = 0;
+       _result.text = "";
+       _result.stopped_limited = true;
+       _result.truncated = false;
+       _result.context_full = false;
+       _result.stopped_eos = false;
+       _result.stopped_words = false;
+       if (_onComplete) {
+         _onComplete();
+       }
+       return;
+     }
+
+     auto completion = _rn_ctx->completion;

-     // Set reasoning format using the common function
-     chat_syntax.reasoning_format = common_reasoning_format_from_name(_reasoning_format);
+     // Prepare completion context
+     completion->rewind();

-     chat_syntax.parse_tool_calls = true;
+     // Set up parameters
+     _rn_ctx->params.prompt = _params.prompt;
+     _rn_ctx->params.sampling = _params.sampling;
+     _rn_ctx->params.antiprompt = _stop_words;
+     _rn_ctx->params.n_predict = _params.n_predict;
+     _rn_ctx->params.n_ctx = _params.n_ctx;
+     _rn_ctx->params.n_batch = _params.n_batch;
+     _rn_ctx->params.ctx_shift = _params.ctx_shift;

-     // Combine prefill_text with generated_text for parsing
-     std::string full_text = _prefill_text + generated_text;
+     // Set prefill text
+     completion->prefill_text = _prefill_text;

-     // Use is_partial=true for streaming partial output
-     common_chat_msg parsed_msg = common_chat_parse(full_text, true, chat_syntax);
+     // Set up TTS guide tokens if enabled
+     if (_has_vocoder && _rn_ctx->tts_wrapper != nullptr) {
+       _rn_ctx->tts_wrapper->guide_tokens = _guide_tokens;
+       _rn_ctx->tts_wrapper->next_token_uses_guide_token = true;
+     }

-     result.content = parsed_msg.content;
-     result.reasoning_content = parsed_msg.reasoning_content;
-     result.tool_calls = parsed_msg.tool_calls;
-   } catch (const std::exception &e) {
-     // If parsing fails, leave content empty - this is expected for partial content
-   }
-
-   return result;
- }
-
- void LlamaCompletionWorker::Execute() {
-   _sess->get_mutex().lock();
-   const auto t_main_start = ggml_time_us();
-   const size_t n_ctx = _params.n_ctx;
-   const auto n_keep = _params.n_keep;
-   size_t n_cur = 0;
-   size_t n_input = 0;
-   const auto model = _sess->model();
-   auto vocab = llama_model_get_vocab(model);
-   const bool is_enc_dec = llama_model_has_encoder(model);
-
-   const bool add_bos = llama_vocab_get_add_bos(vocab);
-   auto ctx = _sess->context();
-
-   auto sparams = llama_sampler_chain_default_params();
-
-   LlamaCppSampling sampling{common_sampler_init(model, _params.sampling),
-                             common_sampler_free};
-
-   // Process media if any are provided
-   if (!_media_paths.empty()) {
-     auto *mtmd_ctx = _sess->get_mtmd_ctx();
-
-     if (mtmd_ctx != nullptr) {
-       // Process the media and get the tokens
-       try {
-         n_cur = processMediaPrompt(ctx, mtmd_ctx, _sess, _params, _media_paths);
-       } catch (const std::exception &e) {
-         SetError(e.what());
-         _sess->get_mutex().unlock();
-         return;
-       }
-
-       if (n_cur <= 0) {
-         SetError("Failed to process media");
-         _sess->get_mutex().unlock();
-         return;
-       }
-
-       fprintf(stdout,
-               "[DEBUG] Media processing successful, n_cur=%zu, tokens=%zu\n",
-               n_cur, _sess->tokens_ptr()->size());
-
-       n_input = _sess->tokens_ptr()->size();
-       if (n_cur == n_input) {
-         --n_cur;
-       }
-       n_input -= n_cur;
-     } else {
-       SetError("Multimodal context not initialized");
-       _sess->get_mutex().unlock();
+     // Initialize sampling
+     if (!completion->initSampling()) {
+       SetError("Failed to initialize sampling");
        return;
      }
-   } else {
-     // Text-only path
-     std::vector<llama_token> prompt_tokens =
-         ::common_tokenize(ctx, _params.prompt, add_bos || is_enc_dec, true);
-     n_input = prompt_tokens.size();
-
-     if (_sess->tokens_ptr()->size() > 0) {
-       n_cur = common_tokens_part(*(_sess->tokens_ptr()), prompt_tokens);
-       if (n_cur == n_input) {
-         --n_cur;
-       }
-       n_input -= n_cur;
-       llama_memory_seq_rm(llama_get_memory(ctx), 0, n_cur, -1);
+
+     // Load prompt (handles both text-only and multimodal)
+     completion->loadPrompt(_media_paths);
+
+     // Check if context is full after loading prompt
+     if (completion->context_full) {
+       _result.context_full = true;
+       return;
      }
-     // Set the tokens
-     _sess->set_tokens(std::move(prompt_tokens));
-   }
-
-   const int max_len = _params.n_predict < 0 ? std::numeric_limits<int>::max() : _params.n_predict;
-   auto embd = _sess->tokens_ptr();
-   embd->reserve(embd->size() + max_len);
-
-
-   if (is_enc_dec) {
-     if (n_input > 0) {
-       // Decode tokens in batches using n_batch as chunk size
-       int n_past_batch = n_cur;
-       int n_remaining = n_input;
+
+     // Begin completion with chat format and reasoning settings
+     completion->beginCompletion(_chat_format, common_reasoning_format_from_name(_reasoning_format), _thinking_forced_open);
+
+     // Main completion loop
+     int token_count = 0;
+     const int max_tokens = _params.n_predict < 0 ? std::numeric_limits<int>::max() : _params.n_predict;
+     while (completion->has_next_token && !_interrupted && token_count < max_tokens) {
+       // Get next token using rn-llama completion
+       rnllama::completion_token_output token_output = completion->doCompletion();

-       while (n_remaining > 0) {
-         int n_eval = n_remaining;
-         if (n_eval > _params.n_batch) {
-           n_eval = _params.n_batch;
-         }
-
-         int ret = llama_encode(ctx, llama_batch_get_one(embd->data() + n_past_batch, n_eval));
-         if (ret < 0) {
-           SetError("Failed to encode token batch, code: " + std::to_string(ret) +
-                    ", n_eval: " + std::to_string(n_eval) +
-                    ", n_past_batch: " + std::to_string(n_past_batch));
-           _sess->get_mutex().unlock();
-           return;
-         }
-
-         n_past_batch += n_eval;
-         n_remaining -= n_eval;
-         n_cur += n_eval;
-       }
-     }
-     _result.tokens_evaluated += n_input;
-
-     llama_token decode_bos = llama_model_decoder_start_token(model);
-     if (decode_bos == LLAMA_TOKEN_NULL) {
-       decode_bos = llama_vocab_bos(vocab);
-     }
-
-     embd->emplace_back(decode_bos);
-     common_sampler_accept(sampling.get(), decode_bos, false);
-     n_input = 1;
-   }
-
-   for (int i = 0; (i < max_len || _interrupted) && !_params.vocab_only; i++) {
-     // check if we need to remove some tokens
-     if (embd->size() >= _params.n_ctx) {
-       if (!_params.ctx_shift) {
-         // Context is full and ctx_shift is disabled, so we need to stop
-         _result.context_full = true;
+       if (token_output.tok == -1) {
          break;
        }
-
-       const int n_left = n_cur - n_keep - 1;
-       const int n_discard = n_left / 2;
-
-       auto mem = llama_get_memory(ctx);
-       llama_memory_seq_rm(mem, 0, n_keep + 1, n_keep + n_discard + 1);
-       llama_memory_seq_add(mem, 0, n_keep + 1 + n_discard, n_cur, -n_discard);
-
-       // shift the tokens
-       embd->insert(embd->begin() + n_keep + 1,
-                    embd->begin() + n_keep + 1 + n_discard, embd->end());
-       embd->resize(embd->size() - n_discard);
-
-       n_cur -= n_discard;
-       _result.truncated = true;
-     }
-
-     // For multimodal input, n_past might already be set
-     // Only decode text tokens if we have any input left
-     if (n_input > 0) {
-       // Decode tokens in batches using n_batch as chunk size
-       int n_past_batch = n_cur;
-       int n_remaining = n_input;

-       while (n_remaining > 0) {
-         int n_eval = n_remaining;
-         if (n_eval > _params.n_batch) {
-           n_eval = _params.n_batch;
+       token_count++;
+
+       std::string token_text = common_token_to_piece(_rn_ctx->ctx, token_output.tok);
+       _result.text += token_text;
+
+       // Check for stopping strings after adding the token
+       if (!_stop_words.empty()) {
+         size_t stop_pos = completion->findStoppingStrings(_result.text, token_text.size(), rnllama::STOP_FULL);
+         if (stop_pos != std::string::npos) {
+           // Found a stop word, truncate the result and break
+           _result.text = _result.text.substr(0, stop_pos);
+           break;
          }
-
-         int ret = llama_decode(ctx, llama_batch_get_one(embd->data() + n_past_batch, n_eval));
-         if (ret < 0) {
-           SetError("Failed to decode token batch, code: " + std::to_string(ret) +
-                    ", n_eval: " + std::to_string(n_eval) +
-                    ", n_past_batch: " + std::to_string(n_past_batch));
-           _sess->get_mutex().unlock();
-           return;
+       }
+
+       // Handle streaming callback
+       if (_has_callback && !completion->incomplete) {
+         struct TokenData {
+           std::string token;
+           std::string content;
+           std::string reasoning_content;
+           std::vector<common_chat_tool_call> tool_calls;
+           std::string accumulated_text;
+           std::vector<rnllama::completion_token_output> completion_probabilities;
+           llama_context* ctx;
+         };
+
+         auto partial_output = completion->parseChatOutput(true);
+
+         // Extract completion probabilities if n_probs > 0, similar to iOS implementation
+         std::vector<rnllama::completion_token_output> probs_output;
+         if (_rn_ctx->params.sampling.n_probs > 0) {
+           const std::vector<llama_token> to_send_toks = common_tokenize(_rn_ctx->ctx, token_text, false);
+           size_t probs_pos = std::min(_sent_token_probs_index, completion->generated_token_probs.size());
+           size_t probs_stop_pos = std::min(_sent_token_probs_index + to_send_toks.size(), completion->generated_token_probs.size());
+           if (probs_pos < probs_stop_pos) {
+             probs_output = std::vector<rnllama::completion_token_output>(
+               completion->generated_token_probs.begin() + probs_pos,
+               completion->generated_token_probs.begin() + probs_stop_pos
+             );
+           }
+           _sent_token_probs_index = probs_stop_pos;
         }

-         n_past_batch += n_eval;
-         n_remaining -= n_eval;
+         TokenData *token_data = new TokenData{
+           token_text,
+           partial_output.content,
+           partial_output.reasoning_content,
+           partial_output.tool_calls,
+           partial_output.accumulated_text,
+           probs_output,
+           _rn_ctx->ctx
+         };
+
+         _tsfn.BlockingCall(token_data, [](Napi::Env env, Napi::Function jsCallback,
+                                           TokenData *data) {
+           auto obj = Napi::Object::New(env);
+           obj.Set("token", Napi::String::New(env, data->token));
+           if (!data->content.empty()) {
+             obj.Set("content", Napi::String::New(env, data->content));
+           }
+           if (!data->reasoning_content.empty()) {
+             obj.Set("reasoning_content", Napi::String::New(env, data->reasoning_content));
+           }
+           if (!data->tool_calls.empty()) {
+             Napi::Array tool_calls = Napi::Array::New(env);
+             for (size_t i = 0; i < data->tool_calls.size(); i++) {
+               const auto &tc = data->tool_calls[i];
+               Napi::Object tool_call = Napi::Object::New(env);
+               tool_call.Set("type", "function");
+               Napi::Object function = Napi::Object::New(env);
+               function.Set("name", tc.name);
+               function.Set("arguments", tc.arguments);
+               tool_call.Set("function", function);
+               if (!tc.id.empty()) {
+                 tool_call.Set("id", tc.id);
+               }
+               tool_calls.Set(i, tool_call);
+             }
+             obj.Set("tool_calls", tool_calls);
+           }
+           obj.Set("accumulated_text", Napi::String::New(env, data->accumulated_text));
+
+           // Add completion_probabilities if available
+           if (!data->completion_probabilities.empty()) {
+             obj.Set("completion_probabilities", TokenProbsToArray(env, data->ctx, data->completion_probabilities));
+           }
+
+           delete data;
+           jsCallback.Call({obj});
+         });
       }
     }
-
-     // sample the next token
-     llama_token new_token_id = common_sampler_sample(sampling.get(), ctx, -1);
-
-     // is it an end of generation?
-     if (llama_vocab_is_eog(vocab, new_token_id)) {
-       _result.stopped_eos = true;
-       break;
-     }
-
-     if (_next_token_uses_guide_token && !_guide_tokens.empty() &&
-         !llama_vocab_is_control(vocab, new_token_id) &&
-         !llama_vocab_is_eog(vocab, new_token_id)) {
-       new_token_id = _guide_tokens[0];
-       _guide_tokens.erase(_guide_tokens.begin());
-     }
-     _next_token_uses_guide_token = (new_token_id == 198);
-     common_sampler_accept(sampling.get(), new_token_id, true);

-     // Collect audio tokens for TTS if vocoder is enabled
-     if (_has_vocoder) {
-       if ((_tts_type == OUTETTS_V0_1 || _tts_type == OUTETTS_V0_2 || _tts_type == OUTETTS_V0_3) &&
-           (new_token_id >= 151672 && new_token_id <= 155772)) {
-         _result.audio_tokens.push_back(new_token_id);
-       }
+     // Check stopping conditions
+     if (token_count >= max_tokens) {
+       _result.stopped_limited = true;
+     } else if (!completion->has_next_token && completion->n_remain == 0) {
+       _result.stopped_limited = true;
      }

-     // prepare the next batch
-     embd->emplace_back(new_token_id);
-     auto token = common_token_to_piece(ctx, new_token_id);
-     _result.text += token;
-     n_cur += n_input;
-     _result.tokens_evaluated += n_input;
-     _result.tokens_predicted += 1;
-     n_input = 1;
-     if (_has_callback) {
-       // TODO: When we got possible stop words (startsWith)
-       // we should avoid calling the callback, wait for the next token
-       struct TokenData {
-         std::string token;
-         std::string content;
-         std::string reasoning_content;
-         std::vector<common_chat_tool_call> tool_calls;
-         std::string accumulated_text;
-       };
-
-       auto partial = getPartialOutput(_result.text);
-       TokenData *token_data = new TokenData{token, partial.content, partial.reasoning_content, partial.tool_calls, _result.text};
-
-       _tsfn.BlockingCall(token_data, [](Napi::Env env, Napi::Function jsCallback,
-                                         TokenData *data) {
-         auto obj = Napi::Object::New(env);
-         obj.Set("token", Napi::String::New(env, data->token));
-         if (!data->content.empty()) {
-           obj.Set("content", Napi::String::New(env, data->content));
-         }
-         if (!data->reasoning_content.empty()) {
-           obj.Set("reasoning_content", Napi::String::New(env, data->reasoning_content));
-         }
-         if (!data->tool_calls.empty()) {
-           Napi::Array tool_calls = Napi::Array::New(env);
-           for (size_t i = 0; i < data->tool_calls.size(); i++) {
-             const auto &tc = data->tool_calls[i];
-             Napi::Object tool_call = Napi::Object::New(env);
-             tool_call.Set("type", "function");
-             Napi::Object function = Napi::Object::New(env);
-             function.Set("name", tc.name);
-             function.Set("arguments", tc.arguments);
-             tool_call.Set("function", function);
-             if (!tc.id.empty()) {
-               tool_call.Set("id", tc.id);
-             }
-             tool_calls.Set(i, tool_call);
-           }
-           obj.Set("tool_calls", tool_calls);
-         }
-         obj.Set("accumulated_text", Napi::String::New(env, data->accumulated_text));
-         delete data;
-         jsCallback.Call({obj});
-       });
-     }
-     // check for stop words
-     if (!_stop_words.empty()) {
-       const size_t stop_pos =
-           findStoppingStrings(_result.text, token.size(), _stop_words);
-       if (stop_pos != std::string::npos) {
-         _result.stopped_words = true;
-         _result.stopping_word = _result.text.substr(stop_pos, token.size());
-         _result.text = _result.text.substr(0, stop_pos - 1);
-         break;
-       }
+     // Set completion results from rn-llama completion context
+     // tokens_evaluated should include both prompt tokens and generated tokens that were processed
+     _result.tokens_evaluated = completion->num_prompt_tokens + completion->num_tokens_predicted;
+     _result.tokens_predicted = completion->num_tokens_predicted;
+     _result.truncated = completion->truncated;
+     _result.context_full = completion->context_full;
+     _result.stopped_eos = completion->stopped_eos;
+     _result.stopped_words = completion->stopped_word;
+     _result.stopping_word = completion->stopping_word;
+     _result.stopped_limited = completion->stopped_limit;
+
+     // Get audio tokens if TTS is enabled
+     if (_has_vocoder && _rn_ctx->tts_wrapper != nullptr) {
+       _result.audio_tokens = _rn_ctx->tts_wrapper->audio_tokens;
     }
+
+     // End completion
+     completion->endCompletion();
+
+   } catch (const std::exception &e) {
+     SetError(e.what());
+     return;
    }
-   if (!_result.stopped_eos && !_result.stopped_words) {
-     _result.stopped_limited = true;
-   }
-   const auto t_main_end = ggml_time_us();
-   _sess->get_mutex().unlock();
+
    if (_onComplete) {
      _onComplete();
    }
@@ -357,6 +267,7 @@ void LlamaCompletionWorker::Execute() {
  void LlamaCompletionWorker::OnOK() {
    auto env = Napi::AsyncWorker::Env();
    auto result = Napi::Object::New(env);
+   result.Set("chat_format", Napi::Number::New(env, _chat_format));
    result.Set("tokens_evaluated",
               Napi::Number::New(env, _result.tokens_evaluated));
    result.Set("tokens_predicted", Napi::Number::New(Napi::AsyncWorker::Env(),
@@ -364,7 +275,9 @@ void LlamaCompletionWorker::OnOK() {
    result.Set("truncated", Napi::Boolean::New(env, _result.truncated));
    result.Set("context_full", Napi::Boolean::New(env, _result.context_full));
    result.Set("interrupted", Napi::Boolean::New(env, _interrupted));
-   result.Set("text", Napi::String::New(env, _result.text.c_str()));
+   // Use the generated text from rn-llama completion if available, otherwise use our result text
+   std::string final_text = (_rn_ctx->completion != nullptr) ? _rn_ctx->completion->generated_text : _result.text;
+   result.Set("text", Napi::String::New(env, final_text.c_str()));
    result.Set("stopped_eos", Napi::Boolean::New(env, _result.stopped_eos));
    result.Set("stopped_words", Napi::Boolean::New(env, _result.stopped_words));
    result.Set("stopping_word",
@@ -375,31 +288,15 @@ void LlamaCompletionWorker::OnOK() {
    Napi::Array tool_calls = Napi::Array::New(Napi::AsyncWorker::Env());
    std::string reasoning_content = "";
    std::string content;
-   if (!_interrupted) {
+   if (!_interrupted && _rn_ctx->completion != nullptr) {
      try {
-       common_chat_syntax chat_syntax;
-       chat_syntax.format = static_cast<common_chat_format>(_chat_format);
-       result.Set("chat_format", Napi::Number::New(env, _chat_format));
+       auto final_output = _rn_ctx->completion->parseChatOutput(false);
+       reasoning_content = final_output.reasoning_content;
+       content = final_output.content;

-       chat_syntax.thinking_forced_open = _thinking_forced_open;
-
-       chat_syntax.reasoning_format = common_reasoning_format_from_name(_reasoning_format);
-
-       // Combine prefill_text with generated_text for final parsing
-       std::string full_text = _prefill_text + _result.text;
-       common_chat_msg message = common_chat_parse(
-           full_text,
-           false,
-           chat_syntax
-       );
-       if (!message.reasoning_content.empty()) {
-         reasoning_content = message.reasoning_content;
-       }
-       if (!message.content.empty()) {
-         content = message.content;
-       }
-       for (size_t i = 0; i < message.tool_calls.size(); i++) {
-         const auto &tc = message.tool_calls[i];
+       // Convert tool calls to JavaScript format
+       for (size_t i = 0; i < final_output.tool_calls.size(); i++) {
+         const auto &tc = final_output.tool_calls[i];
          Napi::Object tool_call = Napi::Object::New(env);
          tool_call.Set("type", "function");
          Napi::Object function = Napi::Object::New(env);
@@ -435,7 +332,12 @@ void LlamaCompletionWorker::OnOK() {
      result.Set("audio_tokens", audio_tokens);
    }

-   auto ctx = _sess->context();
+   // Add completion_probabilities to final result
+   if (_rn_ctx->params.sampling.n_probs > 0 && _rn_ctx->completion != nullptr && !_rn_ctx->completion->generated_token_probs.empty()) {
+     result.Set("completion_probabilities", TokenProbsToArray(env, _rn_ctx->ctx, _rn_ctx->completion->generated_token_probs));
+   }
+
+   auto ctx = _rn_ctx->ctx;
    const auto timings_token = llama_perf_context(ctx);

    auto timingsResult = Napi::Object::New(Napi::AsyncWorker::Env());
package/src/LlamaCompletionWorker.h

@@ -1,7 +1,7 @@
  #pragma once

  #include "common.hpp"
- #include "tts_utils.h"
+ #include "rn-llama/rn-llama.h"
  #include <atomic>
  #include <functional>
  #include <napi.h>
@@ -17,7 +17,7 @@ struct CompletionResult {
  class LlamaCompletionWorker : public Napi::AsyncWorker,
                                public Napi::Promise::Deferred {
  public:
-   LlamaCompletionWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
+   LlamaCompletionWorker(const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx,
                          Napi::Function callback, common_params params,
                          std::vector<std::string> stop_words,
                          int32_t chat_format,
@@ -26,7 +26,7 @@ public:
                          const std::vector<std::string> &media_paths = {},
                          const std::vector<llama_token> &guide_tokens = {},
                          bool has_vocoder = false,
-                         tts_type tts_type_val = UNKNOWN,
+                         rnllama::tts_type tts_type_val = rnllama::UNKNOWN,
                          const std::string &prefill_text = "");

    ~LlamaCompletionWorker();
@@ -43,15 +43,8 @@ protected:
    void OnError(const Napi::Error &err) override;

  private:
-   struct PartialOutput {
-     std::string content = "";
-     std::string reasoning_content = "";
-     std::vector<common_chat_tool_call> tool_calls;
-   };

-   PartialOutput getPartialOutput(const std::string &generated_text);
-
-   LlamaSessionPtr _sess;
+   rnllama::llama_rn_context* _rn_ctx;
    common_params _params;
    std::vector<std::string> _stop_words;
    int32_t _chat_format;
@@ -66,7 +59,8 @@ private:
    Napi::ThreadSafeFunction _tsfn;
    bool _next_token_uses_guide_token = true;
    bool _has_vocoder;
-   tts_type _tts_type;
+   rnllama::tts_type _tts_type;
+   size_t _sent_token_probs_index = 0;
    struct {
      size_t tokens_evaluated = 0;
      size_t tokens_predicted = 0;