@fugood/llama.node 0.5.0 → 0.6.1

Files changed (37)
  1. package/CMakeLists.txt +40 -5
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +46 -0
  17. package/lib/index.js +18 -0
  18. package/lib/index.ts +24 -0
  19. package/package.json +4 -1
  20. package/patches/node-api-headers+1.1.0.patch +26 -0
  21. package/src/DecodeAudioTokenWorker.cpp +40 -0
  22. package/src/DecodeAudioTokenWorker.h +22 -0
  23. package/src/EmbeddingWorker.cpp +7 -5
  24. package/src/LlamaCompletionWorker.cpp +64 -50
  25. package/src/LlamaCompletionWorker.h +6 -7
  26. package/src/LlamaContext.cpp +523 -224
  27. package/src/LlamaContext.h +25 -4
  28. package/src/LoadSessionWorker.cpp +4 -2
  29. package/src/SaveSessionWorker.cpp +10 -6
  30. package/src/TokenizeWorker.cpp +10 -5
  31. package/src/addons.cc +8 -11
  32. package/src/common.hpp +92 -93
  33. package/src/tts_utils.cpp +346 -0
  34. package/src/tts_utils.h +62 -0
  35. package/src/win_dynamic_load.c +2102 -0
  36. package/bin/win32/arm64/llama-node.node +0 -0
  37. package/bin/win32/arm64/node.lib +0 -0
package/src/LlamaContext.h CHANGED
@@ -1,15 +1,22 @@
 #include "common.hpp"
-#include "tools/mtmd/mtmd.h"
 #include "tools/mtmd/clip.h"
+#include "tools/mtmd/mtmd.h"
+#include "tts_utils.h"
 
 class LlamaCompletionWorker;
 
+struct vocoder_context {
+  common_params params;
+  std::shared_ptr<llama_model> model;
+  std::shared_ptr<llama_context> context;
+};
+
 class LlamaContext : public Napi::ObjectWrap<LlamaContext> {
 public:
   LlamaContext(const Napi::CallbackInfo &info);
   ~LlamaContext();
   static void ToggleNativeLog(const Napi::CallbackInfo &info);
-  static Napi::Value ModelInfo(const Napi::CallbackInfo& info);
+  static Napi::Value ModelInfo(const Napi::CallbackInfo &info);
   static void Init(Napi::Env env, Napi::Object &exports);
 
 private:
@@ -27,21 +34,35 @@ private:
   void RemoveLoraAdapters(const Napi::CallbackInfo &info);
   Napi::Value GetLoadedLoraAdapters(const Napi::CallbackInfo &info);
   Napi::Value Release(const Napi::CallbackInfo &info);
-
+
   // Multimodal methods
   Napi::Value InitMultimodal(const Napi::CallbackInfo &info);
   Napi::Value IsMultimodalEnabled(const Napi::CallbackInfo &info);
   Napi::Value GetMultimodalSupport(const Napi::CallbackInfo &info);
   void ReleaseMultimodal(const Napi::CallbackInfo &info);
 
+  // TTS methods
+  tts_type getTTSType(Napi::Env env, nlohmann::json speaker = nullptr);
+  Napi::Value InitVocoder(const Napi::CallbackInfo &info);
+  void ReleaseVocoder(const Napi::CallbackInfo &info);
+  Napi::Value IsVocoderEnabled(const Napi::CallbackInfo &info);
+  Napi::Value GetFormattedAudioCompletion(const Napi::CallbackInfo &info);
+  Napi::Value GetAudioCompletionGuideTokens(const Napi::CallbackInfo &info);
+  Napi::Value DecodeAudioTokens(const Napi::CallbackInfo &info);
+
   std::string _info;
   Napi::Object _meta;
   LlamaSessionPtr _sess = nullptr;
   common_chat_templates_ptr _templates;
   std::vector<common_adapter_lora_info> _lora;
   LlamaCompletionWorker *_wip = nullptr;
-
+
   // Multimodal support
   mtmd_context *_mtmd_ctx = nullptr;
   bool _has_multimodal = false;
+
+  // Vocoder support
+  tts_type _tts_type = UNKNOWN;
+  vocoder_context _vocoder;
+  bool _has_vocoder = false;
 };
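
The header above introduces a dedicated TTS path: a vocoder_context that co-owns a second model/context pair alongside the main LLM. As a rough illustration only (this is not the package's InitVocoder implementation, which this diff does not show; names and parameter choices here are assumptions against a recent llama.h), the shared_ptr members pair naturally with llama.h's C API and custom deleters:

    // Hypothetical sketch: loading a secondary vocoder model into
    // vocoder_context via llama.h's C API with shared_ptr deleters.
    vocoder_context load_vocoder_sketch(const char *gguf_path) {
      vocoder_context vc;
      llama_model_params mparams = llama_model_default_params();
      vc.model = std::shared_ptr<llama_model>(
          llama_model_load_from_file(gguf_path, mparams), llama_model_free);

      llama_context_params cparams = llama_context_default_params();
      cparams.embeddings = true; // assumption: a vocoder is run in embedding mode
      vc.context = std::shared_ptr<llama_context>(
          llama_init_from_model(vc.model.get(), cparams), llama_free);
      return vc;
    }

Holding the pair in shared_ptrs means ReleaseVocoder can simply reset _vocoder and let the deleters run, mirroring how _mtmd_ctx is torn down for multimodal.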
package/src/LoadSessionWorker.cpp CHANGED
@@ -12,8 +12,10 @@ void LoadSessionWorker::Execute() {
   std::vector<llama_token> tokens;
   tokens.reserve(_sess->params().n_ctx);
 
-  // Find LLAMA_TOKEN_NULL in the tokens and resize the array to the index of the null token
-  auto null_token_iter = std::find(tokens.begin(), tokens.end(), LLAMA_TOKEN_NULL);
+  // Find LLAMA_TOKEN_NULL in the tokens and resize the array to the index of
+  // the null token
+  auto null_token_iter =
+      std::find(tokens.begin(), tokens.end(), LLAMA_TOKEN_NULL);
   if (null_token_iter != tokens.end()) {
     tokens.resize(std::distance(tokens.begin(), null_token_iter));
   }
package/src/SaveSessionWorker.cpp CHANGED
@@ -9,16 +9,20 @@ SaveSessionWorker::SaveSessionWorker(const Napi::CallbackInfo &info,
 void SaveSessionWorker::Execute() {
   _sess->get_mutex().lock();
   auto tokens = _sess->tokens_ptr();
-  auto tokens_to_save = std::vector<llama_token>(tokens->begin(), tokens->end());
+  auto tokens_to_save =
+      std::vector<llama_token>(tokens->begin(), tokens->end());
 
-  // Find LLAMA_TOKEN_NULL in the tokens and resize the array to the index of the null token
-  auto null_token_iter = std::find(tokens_to_save.begin(), tokens_to_save.end(), LLAMA_TOKEN_NULL);
+  // Find LLAMA_TOKEN_NULL in the tokens and resize the array to the index of
+  // the null token
+  auto null_token_iter =
+      std::find(tokens_to_save.begin(), tokens_to_save.end(), LLAMA_TOKEN_NULL);
   if (null_token_iter != tokens_to_save.end()) {
-    tokens_to_save.resize(std::distance(tokens_to_save.begin(), null_token_iter));
+    tokens_to_save.resize(
+        std::distance(tokens_to_save.begin(), null_token_iter));
   }
 
-  if (!llama_state_save_file(_sess->context(), _path.c_str(), tokens_to_save.data(),
-                             tokens_to_save.size())) {
+  if (!llama_state_save_file(_sess->context(), _path.c_str(),
+                             tokens_to_save.data(), tokens_to_save.size())) {
     SetError("Failed to save session");
   }
   _sess->get_mutex().unlock();
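
Both session workers share one idiom: the token list is cut at the first LLAMA_TOKEN_NULL before use, since tokenizeWithMedia (in common.hpp below) pads media-chunk positions with that sentinel, presumably so a saved text session stops at the first media chunk. The idiom in isolation, as a self-contained sketch:

    // Truncate-at-sentinel sketch: keep only the prefix before the first
    // LLAMA_TOKEN_NULL placeholder (llama.h defines it as -1).
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t;
    constexpr llama_token LLAMA_TOKEN_NULL = -1;

    void truncate_at_null(std::vector<llama_token> &tokens) {
      auto it = std::find(tokens.begin(), tokens.end(), LLAMA_TOKEN_NULL);
      if (it != tokens.end()) {
        tokens.resize(std::distance(tokens.begin(), it));
      }
    }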
package/src/TokenizeWorker.cpp CHANGED
@@ -2,8 +2,10 @@
 #include "LlamaContext.h"
 
 TokenizeWorker::TokenizeWorker(const Napi::CallbackInfo &info,
-                               LlamaSessionPtr &sess, std::string text, std::vector<std::string> media_paths)
-    : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text), _media_paths(media_paths) {}
+                               LlamaSessionPtr &sess, std::string text,
+                               std::vector<std::string> media_paths)
+    : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text),
+      _media_paths(media_paths) {}
 
 void TokenizeWorker::Execute() {
   auto mtmd_ctx = _sess->get_mtmd_ctx();
@@ -31,17 +33,20 @@ void TokenizeWorker::OnOK() {
   result.Set("tokens", tokens);
   result.Set("has_media", _result.has_media);
   if (_result.has_media) {
-    auto bitmap_hashes = Napi::Array::New(Napi::AsyncWorker::Env(), _result.bitmap_hashes.size());
+    auto bitmap_hashes = Napi::Array::New(Napi::AsyncWorker::Env(),
+                                          _result.bitmap_hashes.size());
     for (size_t i = 0; i < _result.bitmap_hashes.size(); i++) {
       bitmap_hashes.Set(i, _result.bitmap_hashes[i]);
     }
     result.Set("bitmap_hashes", bitmap_hashes);
-    auto chunk_pos = Napi::Array::New(Napi::AsyncWorker::Env(), _result.chunk_pos.size());
+    auto chunk_pos =
+        Napi::Array::New(Napi::AsyncWorker::Env(), _result.chunk_pos.size());
     for (size_t i = 0; i < _result.chunk_pos.size(); i++) {
      chunk_pos.Set(i, _result.chunk_pos[i]);
     }
     result.Set("chunk_pos", chunk_pos);
-    auto chunk_pos_media = Napi::Array::New(Napi::AsyncWorker::Env(), _result.chunk_pos_media.size());
+    auto chunk_pos_media = Napi::Array::New(Napi::AsyncWorker::Env(),
+                                            _result.chunk_pos_media.size());
     for (size_t i = 0; i < _result.chunk_pos_media.size(); i++) {
       chunk_pos_media.Set(i, _result.chunk_pos_media[i]);
     }
package/src/addons.cc CHANGED
@@ -5,25 +5,22 @@
 extern "C" void cleanup_logging();
 
 // Register cleanup function on module unload
-static Napi::Value register_cleanup(const Napi::CallbackInfo& info) {
-  napi_add_env_cleanup_hook(info.Env(), [](void*) {
-    cleanup_logging();
-  }, nullptr);
-
+static Napi::Value register_cleanup(const Napi::CallbackInfo &info) {
+  napi_add_env_cleanup_hook(
+      info.Env(), [](void *) { cleanup_logging(); }, nullptr);
+
   return info.Env().Undefined();
 }
 
 Napi::Object Init(Napi::Env env, Napi::Object exports) {
   LlamaContext::Init(env, exports);
-
+
   // Register our cleanup handler for module unload
   exports.Set("__registerCleanup", Napi::Function::New(env, register_cleanup));
-
+
   // Also register cleanup directly on module init
-  napi_add_env_cleanup_hook(env, [](void*) {
-    cleanup_logging();
-  }, nullptr);
-
+  napi_add_env_cleanup_hook(env, [](void *) { cleanup_logging(); }, nullptr);
+
   return exports;
 }
 
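One detail worth noting in the reformatted hooks: napi_add_env_cleanup_hook takes a plain C function pointer, void (*)(void *), and a capture-less lambda converts to it implicitly, which is why the inline form above compiles. A minimal standalone sketch of the same pattern (the helper name is illustrative):

    // Capture-less lambdas decay to the C function pointer that
    // napi_add_env_cleanup_hook expects; a capturing lambda would not compile.
    #include <node_api.h>

    extern "C" void cleanup_logging();

    void register_hooks_sketch(napi_env env) {
      napi_add_env_cleanup_hook(env, [](void *) { cleanup_logging(); }, nullptr);
    }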
package/src/common.hpp CHANGED
@@ -1,11 +1,10 @@
 #pragma once
 
+#include "chat.h"
 #include "common/common.h"
 #include "common/sampling.h"
-#include "tools/mtmd/mtmd.h"
-#include "tools/mtmd/clip.h"
-#include "chat.h"
 #include "llama.h"
+#include "tools/mtmd/clip.h"
 #include "tools/mtmd/mtmd.h"
 #include <memory>
 #include <mutex>
@@ -27,13 +26,17 @@ static std::string json_stringify(const Napi::Object &obj) {
   Napi::Env env = obj.Env();
   Napi::Object json = env.Global().Get("JSON").As<Napi::Object>();
   Napi::Function stringify = json.Get("stringify").As<Napi::Function>();
-  return stringify.Call(json, { obj }).As<Napi::String>().ToString();
+  return stringify.Call(json, {obj}).As<Napi::String>().ToString();
 }
 
-static void console_log(Napi::Env env, const std::string& message) {
-  Napi::Function consoleLog = env.Global().Get("console").As<Napi::Object>().Get("log").As<Napi::Function>();
-  consoleLog.Call({ Napi::String::New(env, message) });
-}
+static void console_log(Napi::Env env, const std::string &message) {
+  Napi::Function consoleLog = env.Global()
+                                  .Get("console")
+                                  .As<Napi::Object>()
+                                  .Get("log")
+                                  .As<Napi::Function>();
+  consoleLog.Call({Napi::String::New(env, message)});
+}
 
 template <typename T>
 constexpr T get_option(const Napi::Object &options, const std::string &name,
@@ -64,8 +67,7 @@ constexpr T get_option(const Napi::Object &options, const std::string &name,
 
 class LlamaSession {
 public:
-  LlamaSession(common_params params)
-    : params_(params) {
+  LlamaSession(common_params params) : params_(params) {
   llama_init_ = common_init_from_params(params);
   tokens_.reserve(params.n_ctx);
  }
@@ -93,21 +95,17 @@ public:
   inline const common_params &params() const { return params_; }
 
   inline std::mutex &get_mutex() { return mutex; }
-
+
   // Getter for the multimodal context
-  inline const mtmd_context* get_mtmd_ctx() const {
-    return _mtmd_ctx;
-  }
-
+  inline const mtmd_context *get_mtmd_ctx() const { return _mtmd_ctx; }
+
   // Setter for the multimodal context
-  inline void set_mtmd_ctx(mtmd_context* ctx) {
-    _mtmd_ctx = ctx;
-  }
+  inline void set_mtmd_ctx(mtmd_context *ctx) { _mtmd_ctx = ctx; }
 
   void dispose() {
     std::lock_guard<std::mutex> lock(mutex);
     tokens_.clear();
-
+
     // mtmd_ctx is owned by LlamaContext, so we don't free it here
     _mtmd_ctx = nullptr;
   }
@@ -118,13 +116,13 @@ private:
   std::vector<llama_token> tokens_{};
   std::vector<std::string> mtmd_bitmap_past_hashes_{};
   std::mutex mutex;
-  mtmd_context* _mtmd_ctx = nullptr;
+  mtmd_context *_mtmd_ctx = nullptr;
 };
 
 typedef std::shared_ptr<LlamaSession> LlamaSessionPtr;
 
 static size_t common_tokens_part(const std::vector<llama_token> &a,
-                                const std::vector<llama_token> &b) {
+                                 const std::vector<llama_token> &b) {
   size_t i = 0;
   while (i < a.size() && i < b.size() && a[i] == b[i]) {
     i++;
@@ -133,7 +131,7 @@ static size_t common_tokens_part(const std::vector<llama_token> &a,
 }
 
 // Computes FNV-1a hash of the data
-static std::string fnv_hash(const uint8_t * data, size_t len) {
+static std::string fnv_hash(const uint8_t *data, size_t len) {
   const uint64_t fnv_prime = 0x100000001b3ULL;
   uint64_t hash = 0xcbf29ce484222325ULL;
 
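The hunk above ends before the hash loop, which is unchanged in this release. For reference, FNV-1a folds each byte in with XOR-then-multiply; a textbook sketch using the same 64-bit constants (the actual loop body is simply not part of this diff):

    // Standard 64-bit FNV-1a (sketch; the real loop is outside this hunk).
    #include <cstddef>
    #include <cstdint>
    #include <string>

    static std::string fnv_hash_sketch(const uint8_t *data, size_t len) {
      const uint64_t fnv_prime = 0x100000001b3ULL;
      uint64_t hash = 0xcbf29ce484222325ULL;
      for (size_t i = 0; i < len; ++i) {
        hash ^= data[i];   // XOR the byte in first...
        hash *= fnv_prime; // ...then multiply: the "1a" ordering
      }
      return std::to_string(hash);
    }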
@@ -144,10 +142,9 @@ static std::string fnv_hash(const uint8_t * data, size_t len) {
   return std::to_string(hash);
 }
 
-static const std::string base64_chars =
-    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
-    "abcdefghijklmnopqrstuvwxyz"
-    "0123456789+/";
+static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+                                        "abcdefghijklmnopqrstuvwxyz"
+                                        "0123456789+/";
 
 // Base64 decoding function
 static std::vector<uint8_t> base64_decode(const std::string &encoded_string) {
@@ -164,18 +161,22 @@ static std::vector<uint8_t> base64_decode(const std::string &encoded_string) {
       continue;
     }
 
-    if (encoded_string[in_] == '=' || base64_chars.find(encoded_string[in_]) == std::string::npos) {
+    if (encoded_string[in_] == '=' ||
+        base64_chars.find(encoded_string[in_]) == std::string::npos) {
       break;
     }
 
-    char_array_4[i++] = encoded_string[in_]; in_++;
+    char_array_4[i++] = encoded_string[in_];
+    in_++;
     if (i == 4) {
       for (i = 0; i < 4; i++) {
         char_array_4[i] = base64_chars.find(char_array_4[i]);
       }
 
-      char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
-      char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+      char_array_3[0] =
+          (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
+      char_array_3[1] =
+          ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
       char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
 
       for (i = 0; i < 3; i++) {
@@ -195,7 +196,8 @@ static std::vector<uint8_t> base64_decode(const std::string &encoded_string) {
   }
 
   char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
-  char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+  char_array_3[1] =
+      ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
   char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
 
   for (j = 0; j < i - 1; j++) {
@@ -211,16 +213,14 @@ struct TokenizeResult {
 
   bool has_media = false;
   std::vector<std::string> bitmap_hashes;
-  std::vector<size_t> chunk_pos; // both text and media
+  std::vector<size_t> chunk_pos;       // both text and media
   std::vector<size_t> chunk_pos_media; // media only
-  mtmd_input_chunks* chunks = nullptr;
+  mtmd_input_chunks *chunks = nullptr;
 };
 
-static TokenizeResult tokenizeWithMedia(
-    const mtmd_context* mtmd_ctx,
-    const std::string &prompt,
-    const std::vector<std::string> &media_paths
-) {
+static TokenizeResult
+tokenizeWithMedia(const mtmd_context *mtmd_ctx, const std::string &prompt,
+                  const std::vector<std::string> &media_paths) {
   if (mtmd_ctx == nullptr) {
     throw std::runtime_error("Multimodal context is not initialized");
   }
@@ -231,19 +231,22 @@
   mtmd::bitmaps bitmaps;
 
   // Load all media paths
-  for (const auto& media_path : media_paths) {
-    fprintf(stdout, "[DEBUG] Loading media: %s\n",
-            media_path.substr(0, 50).c_str()); // Only log part of path for base64
+  for (const auto &media_path : media_paths) {
+    fprintf(
+        stdout, "[DEBUG] Loading media: %s\n",
+        media_path.substr(0, 50).c_str()); // Only log part of path for base64
 
     // Check if it's a base64 media
-    if (media_path.compare(0, 11, "data:image/") == 0 || media_path.compare(0, 11, "data:audio/") == 0) {
+    if (media_path.compare(0, 11, "data:image/") == 0 ||
+        media_path.compare(0, 11, "data:audio/") == 0) {
 
       // Parse base64 data
       std::vector<std::string> parts;
       size_t comma_pos = media_path.find(',');
       if (comma_pos == std::string::npos) {
         result.bitmap_hashes.clear();
-        throw std::runtime_error("Invalid base64 media format, missing comma separator");
+        throw std::runtime_error(
+            "Invalid base64 media format, missing comma separator");
       }
 
       std::string header = media_path.substr(0, comma_pos);
@@ -260,7 +263,8 @@
       std::vector<uint8_t> media_data = base64_decode(base64_data);
 
       // Load bitmap from memory buffer using direct initialization
-      mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(media_data.data(), media_data.size()));
+      mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(media_data.data(),
+                                                        media_data.size()));
       if (!bmp.ptr) {
         bitmaps.entries.clear();
         throw std::runtime_error("Failed to load base64 media");
@@ -271,18 +275,19 @@
       bmp.set_id(hash.c_str());
       bitmaps.entries.push_back(std::move(bmp));
       result.bitmap_hashes.push_back(hash.c_str());
-      } catch (const std::exception& e) {
+      } catch (const std::exception &e) {
        bitmaps.entries.clear();
        throw std::runtime_error("Failed to decode base64 media");
      }
-    } else if (media_path.compare(0, 7, "http://") == 0 || media_path.compare(0, 8, "https://") == 0) {
+    } else if (media_path.compare(0, 7, "http://") == 0 ||
+               media_path.compare(0, 8, "https://") == 0) {
       // HTTP URLs are not supported yet
       bitmaps.entries.clear();
       throw std::runtime_error("HTTP/HTTPS URLs are not supported yet");
     } else {
       // Regular file path
       // Check if file exists
-      FILE* file = fopen(media_path.c_str(), "rb");
+      FILE *file = fopen(media_path.c_str(), "rb");
       if (file == nullptr) {
         bitmaps.entries.clear();
         throw std::runtime_error("File does not exist or cannot be opened");
@@ -302,7 +307,7 @@
       }
 
       // Calculate bitmap hash (for KV caching)
-      std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
+      std::string hash = fnv_hash(bmp.data(), bmp.nx() * bmp.ny() * 3);
       bmp.set_id(hash.c_str());
       bitmaps.entries.push_back(std::move(bmp));
       result.bitmap_hashes.push_back(hash.c_str());
@@ -314,26 +319,23 @@
     bitmaps.entries.clear();
     throw std::runtime_error("Failed to initialize input chunks");
   }
-
+
   // Create input text
   mtmd_input_text input_text;
   input_text.text = prompt.c_str(); // Use the full prompt with media marker
-  input_text.add_special = true; // Add BOS token if this is the first message
-  input_text.parse_special = true; // Parse special tokens like <__media__>
+  input_text.add_special = true;  // Add BOS token if this is the first message
+  input_text.parse_special = true; // Parse special tokens like <__media__>
 
   // Tokenize the text and media
-  fprintf(stdout, "[DEBUG] Tokenizing text and %zu media\n", bitmaps.entries.size());
+  fprintf(stdout, "[DEBUG] Tokenizing text and %zu media\n",
+          bitmaps.entries.size());
   auto bitmaps_c_ptr = bitmaps.c_ptr();
-
+
   // Cast away const for mtmd_tokenize
-  int32_t res = mtmd_tokenize(
-      const_cast<mtmd_context*>(mtmd_ctx),
-      result.chunks,
-      &input_text,
-      bitmaps_c_ptr.data(),
-      bitmaps_c_ptr.size()
-  );
-
+  int32_t res =
+      mtmd_tokenize(const_cast<mtmd_context *>(mtmd_ctx), result.chunks,
+                    &input_text, bitmaps_c_ptr.data(), bitmaps_c_ptr.size());
+
   if (res != 0) {
     mtmd_input_chunks_free(result.chunks);
     bitmaps.entries.clear();
@@ -342,7 +344,8 @@
 
   // Log chunk information
   size_t num_chunks = mtmd_input_chunks_size(result.chunks);
-  fprintf(stdout, "[DEBUG] Tokenization successful: num_chunks=%zu\n", num_chunks);
+  fprintf(stdout, "[DEBUG] Tokenization successful: num_chunks=%zu\n",
+          num_chunks);
 
   // Track the total number of tokens (both text and media)
   size_t total_token_count = 0;
@@ -351,22 +354,25 @@
   for (size_t i = 0; i < num_chunks; i++) {
     result.chunk_pos.push_back(total_token_count);
 
-    const mtmd_input_chunk* chunk = mtmd_input_chunks_get(result.chunks, i);
+    const mtmd_input_chunk *chunk = mtmd_input_chunks_get(result.chunks, i);
     mtmd_input_chunk_type chunk_type = mtmd_input_chunk_get_type(chunk);
 
     if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
       size_t n_tokens;
-      const llama_token* tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+      const llama_token *tokens =
+          mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
 
       result.tokens.insert(result.tokens.end(), tokens, tokens + n_tokens);
       total_token_count += n_tokens;
-    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ||
+               chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
       result.chunk_pos_media.push_back(total_token_count);
 
       size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
       size_t n_pos = mtmd_input_chunk_get_n_pos(chunk);
       fprintf(stdout, "[DEBUG] Chunk %zu: type=%s, n_tokens=%zu, n_pos=%zu\n",
-              i, chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "IMAGE" : "AUDIO", n_tokens, n_pos);
+              i, chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "IMAGE" : "AUDIO",
+              n_tokens, n_pos);
 
       for (size_t j = 0; j < n_pos; j++) {
         result.tokens.push_back(LLAMA_TOKEN_NULL);
@@ -376,18 +382,15 @@
   }
 
   bitmaps.entries.clear();
-
+
   return result;
 }
 
 // Process media and add them to the tokenized input
-static llama_pos processMediaPrompt(
-    llama_context* ctx,
-    const mtmd_context* mtmd_ctx,
-    LlamaSessionPtr sess,
-    const common_params& params,
-    const std::vector<std::string>& media_paths
-) {
+static llama_pos
+processMediaPrompt(llama_context *ctx, const mtmd_context *mtmd_ctx,
+                   LlamaSessionPtr sess, const common_params &params,
+                   const std::vector<std::string> &media_paths) {
   if (mtmd_ctx == nullptr) {
     throw std::runtime_error("Multimodal context is not initialized");
   }
@@ -422,11 +425,10 @@
       break;
     }
     bool is_end = i + 1 == chunk_pos.size();
-    if (
-      chunk_pos[i] < n_past &&
-      (!is_end && chunk_pos[i + 1] > n_past)
-      // is_end & n_past < total_token_count:
-      // don't need to adjust and it will skip eval_chunk_single, let nextToken() to finish the job
+    if (chunk_pos[i] < n_past && (!is_end && chunk_pos[i + 1] > n_past)
+        // is_end & n_past < total_token_count:
+        // don't need to adjust and it will skip eval_chunk_single, let
+        // nextToken() to finish the job
     ) {
       adjusted_n_past = chunk_pos[i];
     }
@@ -437,7 +439,8 @@
     fprintf(stdout, "[DEBUG] Adjusted n_past to %d\n", n_past);
   }
 
-  // Compare bitmap hashes, if they are not the same, backtrack n_past to the position of the first mismatch
+  // Compare bitmap hashes, if they are not the same, backtrack n_past to the
+  // position of the first mismatch
   auto mtmd_bitmap_past_hashes = sess->mtmd_bitmap_past_hashes_ptr();
   if (mtmd_bitmap_past_hashes->size() > 0) {
     for (size_t i = 0; i < bitmap_hashes.size(); i++) {
@@ -462,7 +465,8 @@
   size_t num_chunks = mtmd_input_chunks_size(chunks);
 
   for (size_t i = 0; i < chunk_pos.size(); i++) {
-    fprintf(stdout, "[DEBUG] Evaluating chunk %zu: n_past=%d, chunk_pos=%zu\n", i, n_past, chunk_pos[i]);
+    fprintf(stdout, "[DEBUG] Evaluating chunk %zu: n_past=%d, chunk_pos=%zu\n",
+            i, n_past, chunk_pos[i]);
 
     // Process chunk only if it's after the current n_past
     if (chunk_pos[i] >= new_n_past) {
@@ -471,16 +475,10 @@
 
       // Cast away const for mtmd_helper_eval_chunk_single
       int32_t res = mtmd_helper_eval_chunk_single(
-          const_cast<mtmd_context*>(mtmd_ctx),
-          ctx,
-          chunk,
-          n_past,
-          0,
-          params.n_batch, // batch size
-          chunk_logits_last,
-          &new_n_past
-      );
-
+          const_cast<mtmd_context *>(mtmd_ctx), ctx, chunk, n_past, 0,
+          params.n_batch, // batch size
+          chunk_logits_last, &new_n_past);
+
       if (res != 0) {
         mtmd_input_chunks_free(chunks);
         throw std::runtime_error("Failed to process chunk");
@@ -489,13 +487,14 @@
     }
   }
 
-  if (n_past == all_tokens.size() && n_past > 0 && all_tokens[n_past - 1] != LLAMA_TOKEN_NULL) {
+  if (n_past == all_tokens.size() && n_past > 0 &&
+      all_tokens[n_past - 1] != LLAMA_TOKEN_NULL) {
     // we have to evaluate at least 1 token to generate logits.
     n_past--;
   }
 
   // Update sampling context to process token sequences
-  for (auto & token : all_tokens) {
+  for (auto &token : all_tokens) {
     if (token == LLAMA_TOKEN_NULL) {
       continue;
     }
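
Taken together, the common.hpp changes are formatting and include-order only, but they make the media-path contract of tokenizeWithMedia easy to read off: plain file paths and data: URIs are accepted, while http(s) URLs throw. A hypothetical call (prompt and paths are illustrative; mtmd_ctx is assumed to be an already-initialized multimodal context):

    // Illustrative only: file path + inline base64 data URI (truncated).
    std::vector<std::string> media_paths = {
        "/tmp/photo.png",
        "data:image/png;base64,iVBORw0KGgo...",
    };
    TokenizeResult res = tokenizeWithMedia(
        mtmd_ctx, "<__media__> What is in this image?", media_paths);
    // res.tokens: text tokens with LLAMA_TOKEN_NULL at media positions
    // res.chunk_pos / res.chunk_pos_media: chunk offsets used for KV-cache reuse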