llama-cpp-capacitor 0.0.6 → 0.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. package/cpp/LICENSE +21 -0
  2. package/cpp/README.md +4 -0
  3. package/cpp/anyascii.c +22223 -0
  4. package/cpp/anyascii.h +42 -0
  5. package/cpp/chat-parser.cpp +393 -0
  6. package/cpp/chat-parser.h +120 -0
  7. package/cpp/chat.cpp +2315 -0
  8. package/cpp/chat.h +221 -0
  9. package/cpp/common.cpp +1619 -0
  10. package/cpp/common.h +744 -0
  11. package/cpp/ggml-alloc.c +1028 -0
  12. package/cpp/ggml-alloc.h +76 -0
  13. package/cpp/ggml-backend-impl.h +255 -0
  14. package/cpp/ggml-backend-reg.cpp +600 -0
  15. package/cpp/ggml-backend.cpp +2118 -0
  16. package/cpp/ggml-backend.h +354 -0
  17. package/cpp/ggml-common.h +1878 -0
  18. package/cpp/ggml-cpp.h +39 -0
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2512 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  25. package/cpp/ggml-cpu/arch/arm/quants.c +3650 -0
  26. package/cpp/ggml-cpu/arch/arm/repack.cpp +1891 -0
  27. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  28. package/cpp/ggml-cpu/arch/x86/quants.c +3820 -0
  29. package/cpp/ggml-cpu/arch/x86/repack.cpp +6307 -0
  30. package/cpp/ggml-cpu/arch-fallback.h +215 -0
  31. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  32. package/cpp/ggml-cpu/binary-ops.h +16 -0
  33. package/cpp/ggml-cpu/common.h +73 -0
  34. package/cpp/ggml-cpu/ggml-cpu-impl.h +525 -0
  35. package/cpp/ggml-cpu/ggml-cpu.c +3578 -0
  36. package/cpp/ggml-cpu/ggml-cpu.cpp +672 -0
  37. package/cpp/ggml-cpu/ops.cpp +10587 -0
  38. package/cpp/ggml-cpu/ops.h +114 -0
  39. package/cpp/ggml-cpu/quants.c +1193 -0
  40. package/cpp/ggml-cpu/quants.h +97 -0
  41. package/cpp/ggml-cpu/repack.cpp +1982 -0
  42. package/cpp/ggml-cpu/repack.h +120 -0
  43. package/cpp/ggml-cpu/simd-mappings.h +1184 -0
  44. package/cpp/ggml-cpu/traits.cpp +36 -0
  45. package/cpp/ggml-cpu/traits.h +38 -0
  46. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  47. package/cpp/ggml-cpu/unary-ops.h +28 -0
  48. package/cpp/ggml-cpu/vec.cpp +348 -0
  49. package/cpp/ggml-cpu/vec.h +1121 -0
  50. package/cpp/ggml-cpu.h +145 -0
  51. package/cpp/ggml-impl.h +622 -0
  52. package/cpp/ggml-metal-impl.h +688 -0
  53. package/cpp/ggml-metal.h +66 -0
  54. package/cpp/ggml-metal.m +6833 -0
  55. package/cpp/ggml-opt.cpp +1093 -0
  56. package/cpp/ggml-opt.h +256 -0
  57. package/cpp/ggml-quants.c +5324 -0
  58. package/cpp/ggml-quants.h +106 -0
  59. package/cpp/ggml-threading.cpp +12 -0
  60. package/cpp/ggml-threading.h +14 -0
  61. package/cpp/ggml.c +7108 -0
  62. package/cpp/ggml.h +2492 -0
  63. package/cpp/gguf.cpp +1358 -0
  64. package/cpp/gguf.h +202 -0
  65. package/cpp/json-partial.cpp +256 -0
  66. package/cpp/json-partial.h +38 -0
  67. package/cpp/json-schema-to-grammar.cpp +985 -0
  68. package/cpp/json-schema-to-grammar.h +21 -0
  69. package/cpp/llama-adapter.cpp +388 -0
  70. package/cpp/llama-adapter.h +76 -0
  71. package/cpp/llama-arch.cpp +2355 -0
  72. package/cpp/llama-arch.h +499 -0
  73. package/cpp/llama-batch.cpp +875 -0
  74. package/cpp/llama-batch.h +160 -0
  75. package/cpp/llama-chat.cpp +783 -0
  76. package/cpp/llama-chat.h +65 -0
  77. package/cpp/llama-context.cpp +2748 -0
  78. package/cpp/llama-context.h +306 -0
  79. package/cpp/llama-cparams.cpp +5 -0
  80. package/cpp/llama-cparams.h +41 -0
  81. package/cpp/llama-cpp.h +30 -0
  82. package/cpp/llama-grammar.cpp +1229 -0
  83. package/cpp/llama-grammar.h +173 -0
  84. package/cpp/llama-graph.cpp +1891 -0
  85. package/cpp/llama-graph.h +810 -0
  86. package/cpp/llama-hparams.cpp +180 -0
  87. package/cpp/llama-hparams.h +233 -0
  88. package/cpp/llama-impl.cpp +167 -0
  89. package/cpp/llama-impl.h +61 -0
  90. package/cpp/llama-io.cpp +15 -0
  91. package/cpp/llama-io.h +35 -0
  92. package/cpp/llama-kv-cache-iswa.cpp +318 -0
  93. package/cpp/llama-kv-cache-iswa.h +135 -0
  94. package/cpp/llama-kv-cache.cpp +2059 -0
  95. package/cpp/llama-kv-cache.h +374 -0
  96. package/cpp/llama-kv-cells.h +491 -0
  97. package/cpp/llama-memory-hybrid.cpp +258 -0
  98. package/cpp/llama-memory-hybrid.h +137 -0
  99. package/cpp/llama-memory-recurrent.cpp +1146 -0
  100. package/cpp/llama-memory-recurrent.h +179 -0
  101. package/cpp/llama-memory.cpp +59 -0
  102. package/cpp/llama-memory.h +119 -0
  103. package/cpp/llama-mmap.cpp +600 -0
  104. package/cpp/llama-mmap.h +68 -0
  105. package/cpp/llama-model-loader.cpp +1164 -0
  106. package/cpp/llama-model-loader.h +170 -0
  107. package/cpp/llama-model-saver.cpp +282 -0
  108. package/cpp/llama-model-saver.h +37 -0
  109. package/cpp/llama-model.cpp +19042 -0
  110. package/cpp/llama-model.h +491 -0
  111. package/cpp/llama-sampling.cpp +2575 -0
  112. package/cpp/llama-sampling.h +32 -0
  113. package/cpp/llama-vocab.cpp +3792 -0
  114. package/cpp/llama-vocab.h +176 -0
  115. package/cpp/llama.cpp +358 -0
  116. package/cpp/llama.h +1373 -0
  117. package/cpp/log.cpp +427 -0
  118. package/cpp/log.h +103 -0
  119. package/cpp/minja/chat-template.hpp +550 -0
  120. package/cpp/minja/minja.hpp +3009 -0
  121. package/cpp/nlohmann/json.hpp +25526 -0
  122. package/cpp/nlohmann/json_fwd.hpp +187 -0
  123. package/cpp/regex-partial.cpp +204 -0
  124. package/cpp/regex-partial.h +56 -0
  125. package/cpp/rn-completion.cpp +681 -0
  126. package/cpp/rn-completion.h +116 -0
  127. package/cpp/rn-llama.cpp +345 -0
  128. package/cpp/rn-llama.h +149 -0
  129. package/cpp/rn-mtmd.hpp +602 -0
  130. package/cpp/rn-tts.cpp +591 -0
  131. package/cpp/rn-tts.h +59 -0
  132. package/cpp/sampling.cpp +579 -0
  133. package/cpp/sampling.h +107 -0
  134. package/cpp/tools/mtmd/clip-impl.h +473 -0
  135. package/cpp/tools/mtmd/clip.cpp +4322 -0
  136. package/cpp/tools/mtmd/clip.h +106 -0
  137. package/cpp/tools/mtmd/miniaudio/miniaudio.h +93468 -0
  138. package/cpp/tools/mtmd/mtmd-audio.cpp +769 -0
  139. package/cpp/tools/mtmd/mtmd-audio.h +47 -0
  140. package/cpp/tools/mtmd/mtmd-helper.cpp +460 -0
  141. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  142. package/cpp/tools/mtmd/mtmd.cpp +1066 -0
  143. package/cpp/tools/mtmd/mtmd.h +298 -0
  144. package/cpp/tools/mtmd/stb/stb_image.h +7988 -0
  145. package/cpp/unicode-data.cpp +7034 -0
  146. package/cpp/unicode-data.h +20 -0
  147. package/cpp/unicode.cpp +1061 -0
  148. package/cpp/unicode.h +68 -0
  149. package/package.json +2 -1
@@ -0,0 +1,681 @@
+ #include "rn-completion.h"
+ #include "rn-llama.h"
+ #include "rn-tts.h"
+ #include "rn-mtmd.hpp"
+
+ // Include multimodal support
+ #include "tools/mtmd/mtmd.h"
+ #include "tools/mtmd/mtmd-helper.h"
+ #include "tools/mtmd/clip.h"
+
+ namespace rnllama {
+
+ static bool ends_with(const std::string &str, const std::string &suffix)
+ {
+     return str.size() >= suffix.size() &&
+            0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
+ }
+
+ static size_t find_partial_stop_string(const std::string &stop,
+                                        const std::string &text)
+ {
+     if (!text.empty() && !stop.empty())
+     {
+         const char text_last_char = text.back();
+         for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--)
+         {
+             if (stop[char_index] == text_last_char)
+             {
+                 const std::string current_partial = stop.substr(0, char_index + 1);
+                 if (ends_with(text, current_partial))
+                 {
+                     return text.size() - char_index - 1;
+                 }
+             }
+         }
+     }
+     return std::string::npos;
+ }
+
+ // Helper function to format rerank task: [BOS]query[EOS][SEP]doc[EOS]
+ static std::vector<llama_token> format_rerank(const llama_vocab * vocab, const std::vector<llama_token> & query, const std::vector<llama_token> & doc) {
+     std::vector<llama_token> result;
+
+     // Get EOS token - use SEP token as fallback if EOS is not available
+     llama_token eos_token = llama_vocab_eos(vocab);
+     if (eos_token == LLAMA_TOKEN_NULL) {
+         eos_token = llama_vocab_sep(vocab);
+     }
+
+     result.reserve(doc.size() + query.size() + 4);
+     if (llama_vocab_get_add_bos(vocab)) {
+         result.push_back(llama_vocab_bos(vocab));
+     }
+     result.insert(result.end(), query.begin(), query.end());
+     if (llama_vocab_get_add_eos(vocab)) {
+         result.push_back(eos_token);
+     }
+     if (llama_vocab_get_add_sep(vocab)) {
+         result.push_back(llama_vocab_sep(vocab));
+     }
+     result.insert(result.end(), doc.begin(), doc.end());
+     if (llama_vocab_get_add_eos(vocab)) {
+         result.push_back(eos_token);
+     }
+
+     return result;
+ }
+
+ // Constructor
+ llama_rn_context_completion::llama_rn_context_completion(llama_rn_context* parent)
+     : parent_ctx(parent) {
+ }
+
+ // Destructor
+ llama_rn_context_completion::~llama_rn_context_completion() {
+     if (ctx_sampling != nullptr) {
+         common_sampler_free(ctx_sampling);
+         ctx_sampling = nullptr;
+     }
+ }
+
+ void llama_rn_context_completion::rewind() {
+     is_interrupted = false;
+     parent_ctx->params.antiprompt.clear();
+     parent_ctx->params.sampling.grammar.clear();
+     num_prompt_tokens = 0;
+     num_tokens_predicted = 0;
+     prefill_text = "";
+     generated_text = "";
+     generated_text.reserve(parent_ctx->params.n_ctx);
+     truncated = false;
+     context_full = false;
+     stopped_eos = false;
+     stopped_word = false;
+     stopped_limit = false;
+     stopping_word = "";
+     incomplete = false;
+     n_remain = 0;
+     n_past = 0;
+     parent_ctx->params.sampling.n_prev = parent_ctx->n_ctx;
+     if (parent_ctx->isVocoderEnabled()) {
+         parent_ctx->tts_wrapper->audio_tokens.clear();
+         parent_ctx->tts_wrapper->next_token_uses_guide_token = true;
+         parent_ctx->tts_wrapper->guide_tokens.clear();
+     }
+ }
+
+ bool llama_rn_context_completion::initSampling() {
+     if (ctx_sampling != nullptr) {
+         common_sampler_free(ctx_sampling);
+     }
+     ctx_sampling = common_sampler_init(parent_ctx->model, parent_ctx->params.sampling);
+     return ctx_sampling != nullptr;
+ }
+
+ void llama_rn_context_completion::truncatePrompt(std::vector<llama_token> &prompt_tokens) {
+     const int n_left = parent_ctx->n_ctx - parent_ctx->params.n_keep;
+     const int n_block_size = n_left / 2;
+     const int erased_blocks = (prompt_tokens.size() - parent_ctx->params.n_keep - n_block_size) / n_block_size;
+
+     // Keep n_keep tokens at start of prompt (at most n_ctx - 4)
+     std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + parent_ctx->params.n_keep);
+
+     new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + parent_ctx->params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+     LOG_INFO("input truncated, n_ctx: %d, n_keep: %d, n_left: %d, old_size: %d, new_size: %d",
+         parent_ctx->n_ctx,
+         parent_ctx->params.n_keep,
+         n_left,
+         prompt_tokens.size(),
+         new_tokens.size()
+     );
+
+     truncated = true;
+     prompt_tokens = new_tokens;
+ }
+
+ void llama_rn_context_completion::loadPrompt(const std::vector<std::string> &media_paths) {
+     bool has_media = !media_paths.empty();
+
+     if (!has_media) {
+         std::vector<llama_token> text_tokens;
+         // Text-only path
+         text_tokens = ::common_tokenize(parent_ctx->ctx, parent_ctx->params.prompt, true, true);
+         num_prompt_tokens = text_tokens.size();
+
+         // LOG tokens
+         std::stringstream ss;
+         ss << "\n" << __func__ << ": prompt_tokens = ";
+         for (auto& token : text_tokens) {
+             ss << token << " ";
+         }
+         LOG_INFO("%s\n", ss.str().c_str());
+
+         if (parent_ctx->params.n_keep < 0) {
+             parent_ctx->params.n_keep = (int)num_prompt_tokens;
+         }
+         parent_ctx->params.n_keep = std::min(parent_ctx->n_ctx - 4, parent_ctx->params.n_keep);
+
+         // Handle truncation if needed
+         if (num_prompt_tokens >= (size_t)parent_ctx->n_ctx) {
+             if (!parent_ctx->params.ctx_shift) {
+                 context_full = true;
+                 return;
+             }
+             truncatePrompt(text_tokens);
+             num_prompt_tokens = text_tokens.size();
+             LM_GGML_ASSERT(num_prompt_tokens < (size_t)parent_ctx->n_ctx);
+         }
+
+         // Update sampling context
+         for (auto & token : text_tokens) {
+             common_sampler_accept(ctx_sampling, token, false);
+         }
+
+         // compare the evaluated prompt with the new prompt
+         n_past = common_part(embd, text_tokens);
+
+         embd = text_tokens;
+         if (n_past == num_prompt_tokens) {
+             // we have to evaluate at least 1 token to generate logits.
+             n_past--;
+         }
+
+         // Manage KV cache
+         auto * kv = llama_get_memory(parent_ctx->ctx);
+         llama_memory_seq_rm(kv, 0, n_past, -1);
+
+         LOG_VERBOSE("prompt ingested, n_past: %d, cached: %s, to_eval: %s",
+             n_past,
+             tokens_to_str(parent_ctx->ctx, embd.cbegin(), embd.cbegin() + n_past).c_str(),
+             tokens_to_str(parent_ctx->ctx, embd.cbegin() + n_past, embd.cend()).c_str()
+         );
+     } else {
+         // Multimodal path - process all media paths
+         processMedia(parent_ctx->params.prompt, media_paths);
+         num_prompt_tokens = embd.size();
+     }
+
+     has_next_token = true;
+
+     LOG_INFO("[DEBUG] Input processed: n_past=%d, embd.size=%zu, num_prompt_tokens=%zu, has_media=%d",
+         n_past, embd.size(), num_prompt_tokens, has_media ? 1 : 0);
+ }
+
+ void llama_rn_context_completion::beginCompletion() {
+     beginCompletion(COMMON_CHAT_FORMAT_CONTENT_ONLY, COMMON_REASONING_FORMAT_NONE, false);
+ }
+
+ void llama_rn_context_completion::beginCompletion(int chat_format, common_reasoning_format reasoning_format, bool thinking_forced_open) {
+     // number of tokens to keep when resetting context
+     n_remain = parent_ctx->params.n_predict;
+     llama_perf_context_reset(parent_ctx->ctx);
+     is_predicting = true;
+
+     current_chat_format = chat_format;
+     current_reasoning_format = reasoning_format;
+     current_thinking_forced_open = thinking_forced_open;
+ }
+
+ void llama_rn_context_completion::endCompletion() {
+     is_predicting = false;
+ }
+
+ completion_token_output llama_rn_context_completion::nextToken()
+ {
+     completion_token_output result;
+     result.tok = -1;
+
+     if (embd.size() >= (size_t)parent_ctx->params.n_ctx)
+     {
+         if (!parent_ctx->params.ctx_shift) {
+             // If context shifting is disabled, stop generation
+             LOG_WARNING("context full, n_ctx: %d, tokens: %d", parent_ctx->params.n_ctx, embd.size());
+             has_next_token = false;
+             context_full = true;
+             return result;
+         }
+
+         // Shift context
+
+         const int n_left = n_past - parent_ctx->params.n_keep - 1;
+         const int n_discard = n_left/2;
+
+         auto * kv = llama_get_memory(parent_ctx->ctx);
+         llama_memory_seq_rm (kv, 0, parent_ctx->params.n_keep + 1 , parent_ctx->params.n_keep + n_discard + 1);
+         llama_memory_seq_add(kv, 0, parent_ctx->params.n_keep + 1 + n_discard, n_past, -n_discard);
+
+         for (size_t i = parent_ctx->params.n_keep + 1 + n_discard; i < embd.size(); i++)
+         {
+             embd[i - n_discard] = embd[i];
+         }
+         embd.resize(embd.size() - n_discard);
+
+         n_past -= n_discard;
+         truncated = true;
+
+         LOG_VERBOSE("context shifted, new n_past: %d, new size: %d", n_past, embd.size());
+     }
+
+     bool tg = true;
+     while (n_past < embd.size())
+     {
+         int n_eval = (int)embd.size() - n_past;
+         tg = n_eval == 1;
+         if (n_eval > parent_ctx->params.n_batch)
+         {
+             n_eval = parent_ctx->params.n_batch;
+         }
+         if (llama_decode(parent_ctx->ctx, llama_batch_get_one(&embd[n_past], n_eval)))
+         {
+             LOG_ERROR("failed to eval, n_eval: %d, n_past: %d, n_threads: %d, embd: %s",
+                 n_eval,
+                 n_past,
+                 parent_ctx->params.cpuparams.n_threads,
+                 tokens_to_str(parent_ctx->ctx, embd.cbegin() + n_past, embd.cend()).c_str()
+             );
+             has_next_token = false;
+             return result;
+         }
+         n_past += n_eval;
+
+         if(is_interrupted) {
+             LOG_INFO("Decoding Interrupted");
+             embd.resize(n_past);
+             has_next_token = false;
+             return result;
+         }
+     }
+
+     const llama_vocab* vocab = llama_model_get_vocab(parent_ctx->model);
+
+     if (parent_ctx->params.n_predict == 0)
+     {
+         has_next_token = false;
+         result.tok = llama_vocab_eos(vocab);
+         return result;
+     }
+
+     {
+         // out of user input, sample next token
+         std::vector<llama_token_data> candidates;
+         candidates.reserve(llama_vocab_n_tokens(vocab));
+
+         llama_token new_token_id = common_sampler_sample(ctx_sampling, parent_ctx->ctx, -1);
+
+         if (llama_vocab_is_eog(vocab, new_token_id)) {
+             has_next_token = false;
+             stopped_eos = true;
+             LOG_VERBOSE("EOS: %s", common_token_to_piece(parent_ctx->ctx, new_token_id).c_str());
+             return result;
+         }
+
+         if (parent_ctx->tts_wrapper != nullptr && parent_ctx->tts_wrapper->next_token_uses_guide_token && !parent_ctx->tts_wrapper->guide_tokens.empty() && !llama_vocab_is_control(vocab, new_token_id)) {
+             new_token_id = parent_ctx->tts_wrapper->guide_tokens[0];
+             parent_ctx->tts_wrapper->guide_tokens.erase(parent_ctx->tts_wrapper->guide_tokens.begin());
+         }
+         if (parent_ctx->tts_wrapper != nullptr) {
+             parent_ctx->tts_wrapper->next_token_uses_guide_token = (new_token_id == 198);
+         }
+         result.tok = new_token_id;
+
+         llama_token_data_array cur_p = *common_sampler_get_candidates(ctx_sampling);
+
+         const int32_t n_probs = parent_ctx->params.sampling.n_probs;
+
+         for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
+         {
+             result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
+         }
+
+         common_sampler_accept(ctx_sampling, result.tok, true);
+         if (tg) {
+             num_tokens_predicted++;
+         }
+     }
+
+     // add it to the context
+     embd.push_back(result.tok);
+     // decrement remaining sampling budget
+     --n_remain;
+
+     has_next_token = parent_ctx->params.n_predict == -1 || n_remain != 0;
+     return result;
+ }
+
+ size_t llama_rn_context_completion::findStoppingStrings(const std::string &text, const size_t last_token_size,
+                                                         const stop_type type)
+ {
+     size_t stop_pos = std::string::npos;
+     for (const std::string &word : parent_ctx->params.antiprompt)
+     {
+         size_t pos;
+         if (type == STOP_FULL)
+         {
+             const size_t tmp = word.size() + last_token_size;
+             const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
+             pos = text.find(word, from_pos);
+         }
+         else
+         {
+             pos = find_partial_stop_string(word, text);
+         }
+         if (pos != std::string::npos &&
+             (stop_pos == std::string::npos || pos < stop_pos))
+         {
+             if (type == STOP_FULL)
+             {
+                 stopping_word = word;
+                 stopped_word = true;
+                 has_next_token = false;
+             }
+             stop_pos = pos;
+         }
+     }
+     return stop_pos;
+ }
+
+ completion_token_output llama_rn_context_completion::doCompletion()
+ {
+     completion_token_output token_with_probs = nextToken();
+
+     const std::string token_text = token_with_probs.tok == -1 ? "" : common_token_to_piece(parent_ctx->ctx, token_with_probs.tok);
+     generated_text += token_text;
+
+     if (parent_ctx->isVocoderEnabled()) {
+         tts_type type = parent_ctx->tts_wrapper->getTTSType(parent_ctx);
+         if (parent_ctx->tts_wrapper->type == UNKNOWN) {
+             parent_ctx->tts_wrapper->type = type;
+         }
+         if ((type == OUTETTS_V0_2 || type == OUTETTS_V0_3) && (token_with_probs.tok >= 151672 && token_with_probs.tok <= 155772)) {
+             parent_ctx->tts_wrapper->audio_tokens.push_back(token_with_probs.tok);
+         }
+     }
+
+     if (parent_ctx->params.sampling.n_probs > 0)
+     {
+         generated_token_probs.push_back(token_with_probs);
+     }
+
+     // check if there is incomplete UTF-8 character at the end
+     for (unsigned i = 1; i < 5 && i <= generated_text.size(); ++i) {
+         unsigned char c = generated_text[generated_text.size() - i];
+         if ((c & 0xC0) == 0x80) {
+             // continuation byte: 10xxxxxx
+             continue;
+         }
+         if ((c & 0xE0) == 0xC0) {
+             // 2-byte character: 110xxxxx ...
+             incomplete = i < 2;
+         } else if ((c & 0xF0) == 0xE0) {
+             // 3-byte character: 1110xxxx ...
+             incomplete = i < 3;
+         } else if ((c & 0xF8) == 0xF0) {
+             // 4-byte character: 11110xxx ...
+             incomplete = i < 4;
+         }
+         // else 1-byte character or invalid byte
+         break;
+     }
+
+     if (incomplete && !has_next_token)
+     {
+         has_next_token = true;
+         n_remain++;
+     }
+
+     if (!has_next_token && n_remain == 0)
+     {
+         stopped_limit = true;
+     }
+
+     LOG_VERBOSE("next token, token: %s, token_text: %s, has_next_token: %d, n_remain: %d, num_tokens_predicted: %d, stopped_eos: %d, stopped_word: %d, stopped_limit: %d, stopping_word: %s",
+         common_token_to_piece(parent_ctx->ctx, token_with_probs.tok),
+         tokens_to_output_formatted_string(parent_ctx->ctx, token_with_probs.tok).c_str(),
+         has_next_token,
+         n_remain,
+         num_tokens_predicted,
+         stopped_eos,
+         stopped_word,
+         stopped_limit,
+         stopping_word.c_str()
+     );
+     return token_with_probs;
+ }
+
+ completion_partial_output llama_rn_context_completion::getPartialOutput(const std::string &token_text) {
+     common_chat_syntax syntax;
+     syntax.format = static_cast<common_chat_format>(current_chat_format);
+     syntax.reasoning_format = current_reasoning_format;
+     syntax.thinking_forced_open = current_thinking_forced_open;
+     syntax.parse_tool_calls = true;
+
+     common_chat_msg parsed_msg = common_chat_parse(prefill_text + generated_text, true, syntax);
+
+     completion_partial_output result;
+
+     result.content = parsed_msg.content;
+     result.reasoning_content = parsed_msg.reasoning_content;
+     result.accumulated_text = prefill_text + generated_text;
+     result.tool_calls = parsed_msg.tool_calls;
+
+     return result;
+ }
+
+ std::vector<float> llama_rn_context_completion::getEmbedding(common_params &embd_params)
+ {
+     static const int n_embd = llama_model_n_embd(llama_get_model(parent_ctx->ctx));
+     if (!embd_params.embedding)
+     {
+         LOG_WARNING("embedding disabled, embedding: %s", embd_params.embedding);
+         return std::vector<float>(n_embd, 0.0f);
+     }
+     float *data;
+
+     const enum llama_pooling_type pooling_type = llama_pooling_type(parent_ctx->ctx);
+     printf("pooling_type: %d\n", pooling_type);
+     if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+         data = llama_get_embeddings(parent_ctx->ctx);
+     } else {
+         data = llama_get_embeddings_seq(parent_ctx->ctx, 0);
+     }
+
+     if (!data) {
+         return std::vector<float>(n_embd, 0.0f);
+     }
+     std::vector<float> embedding(data, data + n_embd), out(data, data + n_embd);
+     common_embd_normalize(embedding.data(), out.data(), n_embd, embd_params.embd_normalize);
+     return out;
+ }
+
+ std::vector<float> llama_rn_context_completion::rerank(const std::string &query, const std::vector<std::string> &documents)
+ {
+     std::vector<float> scores;
+
+     // Check if this model supports reranking (requires rank pooling type)
+     const enum llama_pooling_type pooling_type = llama_pooling_type(parent_ctx->ctx);
+     if (pooling_type != LLAMA_POOLING_TYPE_RANK) {
+         throw std::runtime_error("reranking not supported, pooling_type: " + std::to_string(pooling_type));
+     }
+
+     if (!parent_ctx->params.embedding) {
+         throw std::runtime_error("embedding disabled but required for reranking");
+     }
+
+     const llama_vocab * vocab = llama_model_get_vocab(parent_ctx->model);
+     std::vector<llama_token> query_tokens = common_tokenize(vocab, query, false, true);
+
+     scores.reserve(documents.size());
+
+     for (size_t i = 0; i < documents.size(); ++i) {
+         rewind();
+         embd = {};
+
+         const std::string & document = documents[i];
+
+         std::vector<llama_token> doc_tokens = common_tokenize(vocab, document, false, true);
+
+         std::vector<llama_token> rerank_tokens = format_rerank(vocab, query_tokens, doc_tokens);
+
+         llama_memory_clear(llama_get_memory(parent_ctx->ctx), false);
+
+         // Process the rerank input
+         try {
+             parent_ctx->params.prompt = tokens_to_str(parent_ctx->ctx, rerank_tokens.begin(), rerank_tokens.end());
+             initSampling();
+             loadPrompt({}); // No media paths for rerank
+             beginCompletion();
+             doCompletion();
+
+             // Get the rerank score (single embedding value for rank pooling)
+             float *data = llama_get_embeddings_seq(parent_ctx->ctx, 0);
+             if (data) {
+                 scores.push_back(data[0]); // For rank pooling, the score is the first (and only) dimension
+             } else {
+                 scores.push_back(-1e6f); // Default low score if computation failed
+             }
+         } catch (const std::exception &e) {
+             LOG_WARNING("rerank computation failed for document %zu: %s", i, e.what());
+             scores.push_back(-1e6f);
+         }
+         endCompletion();
+
+         // Clear KV cache again to prepare for next document or restore original state
+         llama_memory_clear(llama_get_memory(parent_ctx->ctx), false);
+     }
+
+     return scores;
+ }
+
+ std::string llama_rn_context_completion::bench(int pp, int tg, int pl, int nr)
+ {
+     if (is_predicting) {
+         LOG_ERROR("cannot benchmark while predicting", "");
+         return std::string("[]");
+     }
+
+     is_predicting = true;
+
+     double pp_avg = 0;
+     double tg_avg = 0;
+
+     double pp_std = 0;
+     double tg_std = 0;
+
+     // TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
+     llama_batch batch = llama_batch_init(
+         std::min(pp, parent_ctx->params.n_ubatch), // max n_tokens is limited by n_ubatch
+         0, // No embeddings
+         1 // Single sequence
+     );
+
+     for (int i = 0; i < nr; i++)
+     {
+         llama_batch_clear(&batch);
+
+         const int n_tokens = pp;
+
+         for (int i = 0; i < n_tokens; i++)
+         {
+             llama_batch_add(&batch, 0, i, {0}, false);
+         }
+         batch.logits[batch.n_tokens - 1] = 1; // true
+
+         llama_memory_clear(llama_get_memory(parent_ctx->ctx), true);
+
+         const int64_t t_pp_start = llama_time_us();
+         if (llama_decode(parent_ctx->ctx, batch) != 0)
+         {
+             LOG_ERROR("llama_decode() failed during prompt", "");
+         }
+         const int64_t t_pp_end = llama_time_us();
+
+         llama_memory_clear(llama_get_memory(parent_ctx->ctx), true);
+
+         if (is_interrupted) break;
+
+         const int64_t t_tg_start = llama_time_us();
+
+         for (int i = 0; i < tg; i++)
+         {
+             llama_batch_clear(&batch);
+
+             for (int j = 0; j < pl; j++)
+             {
+                 llama_batch_add(&batch, 0, i, {j}, true);
+             }
+
+             if (llama_decode(parent_ctx->ctx, batch) != 0)
+             {
+                 LOG_ERROR("llama_decode() failed during text generation", "");
+             }
+             if (is_interrupted) break;
+         }
+
+         const int64_t t_tg_end = llama_time_us();
+
+         llama_memory_clear(llama_get_memory(parent_ctx->ctx), true);
+
+         const double t_pp = (t_pp_end - t_pp_start) / 1000000.0;
+         const double t_tg = (t_tg_end - t_tg_start) / 1000000.0;
+
+         const double speed_pp = pp / t_pp;
+         const double speed_tg = (pl * tg) / t_tg;
+
+         pp_avg += speed_pp;
+         tg_avg += speed_tg;
+
+         pp_std += speed_pp * speed_pp;
+         tg_std += speed_tg * speed_tg;
+     }
+
+     pp_avg /= nr;
+     tg_avg /= nr;
+
+     if (nr > 1) {
+         pp_std = sqrt(pp_std / (nr - 1) - pp_avg * pp_avg * nr / (nr - 1));
+         tg_std = sqrt(tg_std / (nr - 1) - tg_avg * tg_avg * nr / (nr - 1));
+     } else {
+         pp_std = 0;
+         tg_std = 0;
+     }
+
+     if (is_interrupted) llama_memory_clear(llama_get_memory(parent_ctx->ctx), true);
+     endCompletion();
+
+     char model_desc[128];
+     llama_model_desc(parent_ctx->model, model_desc, sizeof(model_desc));
+     return std::string("[\"") + model_desc + std::string("\",") +
+         std::to_string(llama_model_size(parent_ctx->model)) + std::string(",") +
+         std::to_string(llama_model_n_params(parent_ctx->model)) + std::string(",") +
+         std::to_string(pp_avg) + std::string(",") +
+         std::to_string(pp_std) + std::string(",") +
+         std::to_string(tg_avg) + std::string(",") +
+         std::to_string(tg_std) +
+         std::string("]");
+ }
+
+ void llama_rn_context_completion::processMedia(
+     const std::string &prompt,
+     const std::vector<std::string> &media_paths
+ ) {
+     if (!parent_ctx->isMultimodalEnabled()) {
+         throw std::runtime_error("Multimodal is not enabled but image paths are provided");
+     }
+
+     // Delegate to the mtmd_wrapper method
+     parent_ctx->mtmd_wrapper->processMedia(
+         parent_ctx->ctx,
+         prompt,
+         media_paths,
+         parent_ctx->n_ctx,
+         parent_ctx->params.n_batch,
+         n_past,
+         embd,
+         context_full,
+         ctx_sampling
+     );
+ }
+
+ } // namespace rnllama
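
Note: the completion entry points added in this file are driven in the same order that rerank() uses internally: rewind/initSampling, then loadPrompt, beginCompletion, repeated doCompletion until has_next_token goes false, and finally endCompletion. The following is a minimal sketch of that text-only flow, not code from the package; it assumes an already-loaded rnllama::llama_rn_context named rn_ctx whose completion member (a hypothetical accessor name here) points at the llama_rn_context_completion instance defined above.

// Sketch only: drives llama_rn_context_completion the way rerank() above does.
#include "rn-llama.h"
#include "rn-completion.h"
#include <string>

std::string run_text_completion(rnllama::llama_rn_context &rn_ctx, const std::string &prompt) {
    auto *comp = rn_ctx.completion;      // hypothetical member holding the completion context
    rn_ctx.params.prompt = prompt;       // loadPrompt() reads the prompt from the parent params
    rn_ctx.params.n_predict = 128;       // sampling budget consumed through n_remain

    comp->rewind();                      // reset per-request state
    if (!comp->initSampling()) {
        return "";
    }
    comp->loadPrompt({});                // text-only path: no media paths
    comp->beginCompletion();
    while (comp->has_next_token && !comp->is_interrupted) {
        comp->doCompletion();            // samples one token and appends it to generated_text
    }
    comp->endCompletion();
    return comp->generated_text;
}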