@shipworthy/ai-sdk-llama-cpp 0.2.3 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,50 +1,49 @@
1
- #include <napi.h>
2
1
  #include "llama-wrapper.h"
2
+
3
+ #include <napi.h>
4
+
5
+ #include <atomic>
3
6
  #include <memory>
4
- #include <unordered_map>
5
7
  #include <mutex>
6
- #include <atomic>
8
+ #include <unordered_map>
7
9
 
8
10
  // Global state for managing models
9
11
  static std::unordered_map<int, std::unique_ptr<llama_wrapper::LlamaModel>> g_models;
10
- static std::mutex g_models_mutex;
11
- static std::atomic<int> g_next_handle{1};
12
+ static std::mutex g_models_mutex;
13
+ static std::atomic<int> g_next_handle{ 1 };
12
14
 
13
15
  // ============================================================================
14
16
  // Async Workers
15
17
  // ============================================================================
16
18
 
17
19
  class LoadModelWorker : public Napi::AsyncWorker {
18
- public:
19
- LoadModelWorker(
20
- Napi::Function& callback,
21
- const std::string& model_path,
22
- int n_gpu_layers,
23
- int n_ctx,
24
- int n_threads,
25
- bool debug,
26
- const std::string& chat_template,
27
- bool embedding
28
- )
29
- : Napi::AsyncWorker(callback)
30
- , model_path_(model_path)
31
- , n_gpu_layers_(n_gpu_layers)
32
- , n_ctx_(n_ctx)
33
- , n_threads_(n_threads)
34
- , debug_(debug)
35
- , chat_template_(chat_template)
36
- , embedding_(embedding)
37
- , handle_(-1)
38
- , success_(false)
39
- {}
20
+ public:
21
+ LoadModelWorker(Napi::Function & callback,
22
+ const std::string & model_path,
23
+ int n_gpu_layers,
24
+ int n_ctx,
25
+ int n_threads,
26
+ bool debug,
27
+ const std::string & chat_template,
28
+ bool embedding) :
29
+ Napi::AsyncWorker(callback),
30
+ model_path_(model_path),
31
+ n_gpu_layers_(n_gpu_layers),
32
+ n_ctx_(n_ctx),
33
+ n_threads_(n_threads),
34
+ debug_(debug),
35
+ chat_template_(chat_template),
36
+ embedding_(embedding),
37
+ handle_(-1),
38
+ success_(false) {}
40
39
 
41
40
  void Execute() override {
42
41
  auto model = std::make_unique<llama_wrapper::LlamaModel>();
43
42
 
44
43
  llama_wrapper::ModelParams model_params;
45
- model_params.model_path = model_path_;
46
- model_params.n_gpu_layers = n_gpu_layers_;
47
- model_params.debug = debug_;
44
+ model_params.model_path = model_path_;
45
+ model_params.n_gpu_layers = n_gpu_layers_;
46
+ model_params.debug = debug_;
48
47
  model_params.chat_template = chat_template_;
49
48
 
50
49
  if (!model->load(model_params)) {
@@ -53,7 +52,7 @@ public:
53
52
  }
54
53
 
55
54
  llama_wrapper::ContextParams ctx_params;
56
- ctx_params.n_ctx = n_ctx_;
55
+ ctx_params.n_ctx = n_ctx_;
57
56
  ctx_params.n_threads = n_threads_;
58
57
  ctx_params.embedding = embedding_;
59
58
 
@@ -74,52 +73,43 @@ public:
74
73
 
75
74
  void OnOK() override {
76
75
  Napi::HandleScope scope(Env());
77
- Callback().Call({
78
- Env().Null(),
79
- Napi::Number::New(Env(), handle_)
80
- });
76
+ Callback().Call({ Env().Null(), Napi::Number::New(Env(), handle_) });
81
77
  }
82
78
 
83
- void OnError(const Napi::Error& e) override {
79
+ void OnError(const Napi::Error & e) override {
84
80
  Napi::HandleScope scope(Env());
85
- Callback().Call({
86
- Napi::String::New(Env(), e.Message()),
87
- Env().Null()
88
- });
81
+ Callback().Call({ Napi::String::New(Env(), e.Message()), Env().Null() });
89
82
  }
90
83
 
91
- private:
84
+ private:
92
85
  std::string model_path_;
93
- int n_gpu_layers_;
94
- int n_ctx_;
95
- int n_threads_;
96
- bool debug_;
86
+ int n_gpu_layers_;
87
+ int n_ctx_;
88
+ int n_threads_;
89
+ bool debug_;
97
90
  std::string chat_template_;
98
- bool embedding_;
99
- int handle_;
100
- bool success_;
91
+ bool embedding_;
92
+ int handle_;
93
+ bool success_;
101
94
  };
102
95
 
103
96
  class GenerateWorker : public Napi::AsyncWorker {
104
- public:
105
- GenerateWorker(
106
- Napi::Function& callback,
107
- int handle,
108
- const std::vector<llama_wrapper::ChatMessage>& messages,
109
- const llama_wrapper::GenerationParams& params
110
- )
111
- : Napi::AsyncWorker(callback)
112
- , handle_(handle)
113
- , messages_(messages)
114
- , params_(params)
115
- {}
97
+ public:
98
+ GenerateWorker(Napi::Function & callback,
99
+ int handle,
100
+ const std::vector<llama_wrapper::ChatMessage> & messages,
101
+ const llama_wrapper::GenerationParams & params) :
102
+ Napi::AsyncWorker(callback),
103
+ handle_(handle),
104
+ messages_(messages),
105
+ params_(params) {}
116
106
 
117
107
  void Execute() override {
118
- llama_wrapper::LlamaModel* model = nullptr;
108
+ llama_wrapper::LlamaModel * model = nullptr;
119
109
 
120
110
  {
121
111
  std::lock_guard<std::mutex> lock(g_models_mutex);
122
- auto it = g_models.find(handle_);
112
+ auto it = g_models.find(handle_);
123
113
  if (it == g_models.end()) {
124
114
  SetError("Invalid model handle");
125
115
  return;
@@ -139,64 +129,54 @@ public:
139
129
  result.Set("completionTokens", Napi::Number::New(Env(), result_.completion_tokens));
140
130
  result.Set("finishReason", Napi::String::New(Env(), result_.finish_reason));
141
131
 
142
- Callback().Call({Env().Null(), result});
132
+ Callback().Call({ Env().Null(), result });
143
133
  }
144
134
 
145
- private:
146
- int handle_;
135
+ private:
136
+ int handle_;
147
137
  std::vector<llama_wrapper::ChatMessage> messages_;
148
- llama_wrapper::GenerationParams params_;
149
- llama_wrapper::GenerationResult result_;
138
+ llama_wrapper::GenerationParams params_;
139
+ llama_wrapper::GenerationResult result_;
150
140
  };
151
141
 
152
142
  // Thread-safe function context for streaming
153
143
  class StreamContext {
154
- public:
155
- StreamContext(Napi::Env env, Napi::Function callback)
156
- : callback_(Napi::Persistent(callback))
157
- , env_(env)
158
- {}
159
-
160
- Napi::FunctionReference callback_;
161
- Napi::Env env_;
144
+ public:
145
+ StreamContext(Napi::Env env, Napi::Function callback) : callback_(Napi::Persistent(callback)), env_(env) {}
146
+
147
+ Napi::FunctionReference callback_;
148
+ Napi::Env env_;
162
149
  llama_wrapper::GenerationResult result_;
163
150
  };
164
151
 
165
- void StreamCallJS(Napi::Env env, Napi::Function callback, StreamContext* context, const char* token) {
152
+ void StreamCallJS(Napi::Env env, Napi::Function callback, StreamContext * context, const char * token) {
166
153
  if (env != nullptr && callback != nullptr) {
167
154
  if (token != nullptr) {
168
155
  // Streaming token
169
- callback.Call({
170
- env.Null(),
171
- Napi::String::New(env, "token"),
172
- Napi::String::New(env, token)
173
- });
156
+ callback.Call({ env.Null(), Napi::String::New(env, "token"), Napi::String::New(env, token) });
174
157
  }
175
158
  }
176
159
  }
177
160
 
178
161
  class StreamGenerateWorker : public Napi::AsyncWorker {
179
- public:
180
- StreamGenerateWorker(
181
- Napi::Function& callback,
182
- int handle,
183
- const std::vector<llama_wrapper::ChatMessage>& messages,
184
- const llama_wrapper::GenerationParams& params,
185
- Napi::Function& token_callback
186
- )
187
- : Napi::AsyncWorker(callback)
188
- , handle_(handle)
189
- , messages_(messages)
190
- , params_(params)
191
- , token_callback_(Napi::Persistent(token_callback))
192
- {}
162
+ public:
163
+ StreamGenerateWorker(Napi::Function & callback,
164
+ int handle,
165
+ const std::vector<llama_wrapper::ChatMessage> & messages,
166
+ const llama_wrapper::GenerationParams & params,
167
+ Napi::Function & token_callback) :
168
+ Napi::AsyncWorker(callback),
169
+ handle_(handle),
170
+ messages_(messages),
171
+ params_(params),
172
+ token_callback_(Napi::Persistent(token_callback)) {}
193
173
 
194
174
  void Execute() override {
195
- llama_wrapper::LlamaModel* model = nullptr;
175
+ llama_wrapper::LlamaModel * model = nullptr;
196
176
 
197
177
  {
198
178
  std::lock_guard<std::mutex> lock(g_models_mutex);
199
- auto it = g_models.find(handle_);
179
+ auto it = g_models.find(handle_);
200
180
  if (it == g_models.end()) {
201
181
  SetError("Invalid model handle");
202
182
  return;
@@ -205,7 +185,7 @@ public:
205
185
  }
206
186
 
207
187
  // Collect tokens during generation
208
- result_ = model->generate_streaming(messages_, params_, [this](const std::string& token) {
188
+ result_ = model->generate_streaming(messages_, params_, [this](const std::string & token) {
209
189
  std::lock_guard<std::mutex> lock(tokens_mutex_);
210
190
  tokens_.push_back(token);
211
191
  return true;
@@ -216,10 +196,8 @@ public:
216
196
  Napi::HandleScope scope(Env());
217
197
 
218
198
  // Call token callback for each collected token
219
- for (const auto& token : tokens_) {
220
- token_callback_.Call({
221
- Napi::String::New(Env(), token)
222
- });
199
+ for (const auto & token : tokens_) {
200
+ token_callback_.Call({ Napi::String::New(Env(), token) });
223
201
  }
224
202
 
225
203
  // Final callback with result
@@ -229,37 +207,32 @@ public:
229
207
  result.Set("completionTokens", Napi::Number::New(Env(), result_.completion_tokens));
230
208
  result.Set("finishReason", Napi::String::New(Env(), result_.finish_reason));
231
209
 
232
- Callback().Call({Env().Null(), result});
210
+ Callback().Call({ Env().Null(), result });
233
211
  }
234
212
 
235
- private:
236
- int handle_;
213
+ private:
214
+ int handle_;
237
215
  std::vector<llama_wrapper::ChatMessage> messages_;
238
- llama_wrapper::GenerationParams params_;
239
- llama_wrapper::GenerationResult result_;
240
- Napi::FunctionReference token_callback_;
241
- std::vector<std::string> tokens_;
242
- std::mutex tokens_mutex_;
216
+ llama_wrapper::GenerationParams params_;
217
+ llama_wrapper::GenerationResult result_;
218
+ Napi::FunctionReference token_callback_;
219
+ std::vector<std::string> tokens_;
220
+ std::mutex tokens_mutex_;
243
221
  };
244
222
 
245
223
  class EmbedWorker : public Napi::AsyncWorker {
246
- public:
247
- EmbedWorker(
248
- Napi::Function& callback,
249
- int handle,
250
- const std::vector<std::string>& texts
251
- )
252
- : Napi::AsyncWorker(callback)
253
- , handle_(handle)
254
- , texts_(texts)
255
- {}
224
+ public:
225
+ EmbedWorker(Napi::Function & callback, int handle, const std::vector<std::string> & texts) :
226
+ Napi::AsyncWorker(callback),
227
+ handle_(handle),
228
+ texts_(texts) {}
256
229
 
257
230
  void Execute() override {
258
- llama_wrapper::LlamaModel* model = nullptr;
231
+ llama_wrapper::LlamaModel * model = nullptr;
259
232
 
260
233
  {
261
234
  std::lock_guard<std::mutex> lock(g_models_mutex);
262
- auto it = g_models.find(handle_);
235
+ auto it = g_models.find(handle_);
263
236
  if (it == g_models.end()) {
264
237
  SetError("Invalid model handle");
265
238
  return;
@@ -268,7 +241,7 @@ public:
268
241
  }
269
242
 
270
243
  result_ = model->embed(texts_);
271
-
244
+
272
245
  if (result_.embeddings.empty() && !texts_.empty()) {
273
246
  SetError("Failed to generate embeddings");
274
247
  return;
@@ -281,7 +254,7 @@ public:
281
254
  // Create embeddings array
282
255
  Napi::Array embeddings_arr = Napi::Array::New(Env(), result_.embeddings.size());
283
256
  for (size_t i = 0; i < result_.embeddings.size(); i++) {
284
- const auto& emb = result_.embeddings[i];
257
+ const auto & emb = result_.embeddings[i];
285
258
  Napi::Float32Array embedding = Napi::Float32Array::New(Env(), emb.size());
286
259
  for (size_t j = 0; j < emb.size(); j++) {
287
260
  embedding[j] = emb[j];
@@ -293,12 +266,12 @@ public:
293
266
  result.Set("embeddings", embeddings_arr);
294
267
  result.Set("totalTokens", Napi::Number::New(Env(), result_.total_tokens));
295
268
 
296
- Callback().Call({Env().Null(), result});
269
+ Callback().Call({ Env().Null(), result });
297
270
  }
298
271
 
299
- private:
300
- int handle_;
301
- std::vector<std::string> texts_;
272
+ private:
273
+ int handle_;
274
+ std::vector<std::string> texts_;
302
275
  llama_wrapper::EmbeddingResult result_;
303
276
  };
304
277
 
@@ -306,7 +279,7 @@ private:
306
279
  // N-API Functions
307
280
  // ============================================================================
308
281
 
309
- Napi::Value LoadModel(const Napi::CallbackInfo& info) {
282
+ Napi::Value LoadModel(const Napi::CallbackInfo & info) {
310
283
  Napi::Env env = info.Env();
311
284
 
312
285
  if (info.Length() < 2 || !info[0].IsObject() || !info[1].IsFunction()) {
@@ -314,30 +287,26 @@ Napi::Value LoadModel(const Napi::CallbackInfo& info) {
314
287
  return env.Null();
315
288
  }
316
289
 
317
- Napi::Object options = info[0].As<Napi::Object>();
290
+ Napi::Object options = info[0].As<Napi::Object>();
318
291
  Napi::Function callback = info[1].As<Napi::Function>();
319
292
 
320
- std::string model_path = options.Get("modelPath").As<Napi::String>().Utf8Value();
321
- int n_gpu_layers = options.Has("gpuLayers") ?
322
- options.Get("gpuLayers").As<Napi::Number>().Int32Value() : 99;
323
- int n_ctx = options.Has("contextSize") ?
324
- options.Get("contextSize").As<Napi::Number>().Int32Value() : 2048;
325
- int n_threads = options.Has("threads") ?
326
- options.Get("threads").As<Napi::Number>().Int32Value() : 4;
327
- bool debug = options.Has("debug") ?
328
- options.Get("debug").As<Napi::Boolean>().Value() : false;
329
- std::string chat_template = options.Has("chatTemplate") ?
330
- options.Get("chatTemplate").As<Napi::String>().Utf8Value() : "auto";
331
- bool embedding = options.Has("embedding") ?
332
- options.Get("embedding").As<Napi::Boolean>().Value() : false;
333
-
334
- auto worker = new LoadModelWorker(callback, model_path, n_gpu_layers, n_ctx, n_threads, debug, chat_template, embedding);
293
+ std::string model_path = options.Get("modelPath").As<Napi::String>().Utf8Value();
294
+ int n_gpu_layers = options.Has("gpuLayers") ? options.Get("gpuLayers").As<Napi::Number>().Int32Value() : 99;
295
+ int n_ctx = options.Has("contextSize") ? options.Get("contextSize").As<Napi::Number>().Int32Value() : 0;
296
+ int n_threads = options.Has("threads") ? options.Get("threads").As<Napi::Number>().Int32Value() : 4;
297
+ bool debug = options.Has("debug") ? options.Get("debug").As<Napi::Boolean>().Value() : false;
298
+ std::string chat_template =
299
+ options.Has("chatTemplate") ? options.Get("chatTemplate").As<Napi::String>().Utf8Value() : "auto";
300
+ bool embedding = options.Has("embedding") ? options.Get("embedding").As<Napi::Boolean>().Value() : false;
301
+
302
+ auto worker =
303
+ new LoadModelWorker(callback, model_path, n_gpu_layers, n_ctx, n_threads, debug, chat_template, embedding);
335
304
  worker->Queue();
336
305
 
337
306
  return env.Undefined();
338
307
  }
339
308
 
340
- Napi::Value UnloadModel(const Napi::CallbackInfo& info) {
309
+ Napi::Value UnloadModel(const Napi::CallbackInfo & info) {
341
310
  Napi::Env env = info.Env();
342
311
 
343
312
  if (info.Length() < 1 || !info[0].IsNumber()) {
@@ -349,7 +318,7 @@ Napi::Value UnloadModel(const Napi::CallbackInfo& info) {
349
318
 
350
319
  {
351
320
  std::lock_guard<std::mutex> lock(g_models_mutex);
352
- auto it = g_models.find(handle);
321
+ auto it = g_models.find(handle);
353
322
  if (it != g_models.end()) {
354
323
  g_models.erase(it);
355
324
  }
@@ -362,16 +331,16 @@ Napi::Value UnloadModel(const Napi::CallbackInfo& info) {
362
331
  std::vector<llama_wrapper::ChatMessage> ParseMessages(Napi::Array messages_arr) {
363
332
  std::vector<llama_wrapper::ChatMessage> messages;
364
333
  for (uint32_t i = 0; i < messages_arr.Length(); i++) {
365
- Napi::Object msg_obj = messages_arr.Get(i).As<Napi::Object>();
334
+ Napi::Object msg_obj = messages_arr.Get(i).As<Napi::Object>();
366
335
  llama_wrapper::ChatMessage msg;
367
- msg.role = msg_obj.Get("role").As<Napi::String>().Utf8Value();
336
+ msg.role = msg_obj.Get("role").As<Napi::String>().Utf8Value();
368
337
  msg.content = msg_obj.Get("content").As<Napi::String>().Utf8Value();
369
338
  messages.push_back(msg);
370
339
  }
371
340
  return messages;
372
341
  }
373
342
 
374
- Napi::Value Generate(const Napi::CallbackInfo& info) {
343
+ Napi::Value Generate(const Napi::CallbackInfo & info) {
375
344
  Napi::Env env = info.Env();
376
345
 
377
346
  if (info.Length() < 3 || !info[0].IsNumber() || !info[1].IsObject() || !info[2].IsFunction()) {
@@ -379,8 +348,8 @@ Napi::Value Generate(const Napi::CallbackInfo& info) {
379
348
  return env.Null();
380
349
  }
381
350
 
382
- int handle = info[0].As<Napi::Number>().Int32Value();
383
- Napi::Object options = info[1].As<Napi::Object>();
351
+ int handle = info[0].As<Napi::Number>().Int32Value();
352
+ Napi::Object options = info[1].As<Napi::Object>();
384
353
  Napi::Function callback = info[2].As<Napi::Function>();
385
354
 
386
355
  // Parse messages array
@@ -391,14 +360,10 @@ Napi::Value Generate(const Napi::CallbackInfo& info) {
391
360
  std::vector<llama_wrapper::ChatMessage> messages = ParseMessages(options.Get("messages").As<Napi::Array>());
392
361
 
393
362
  llama_wrapper::GenerationParams params;
394
- params.max_tokens = options.Has("maxTokens") ?
395
- options.Get("maxTokens").As<Napi::Number>().Int32Value() : 256;
396
- params.temperature = options.Has("temperature") ?
397
- options.Get("temperature").As<Napi::Number>().FloatValue() : 0.7f;
398
- params.top_p = options.Has("topP") ?
399
- options.Get("topP").As<Napi::Number>().FloatValue() : 0.9f;
400
- params.top_k = options.Has("topK") ?
401
- options.Get("topK").As<Napi::Number>().Int32Value() : 40;
363
+ params.max_tokens = options.Has("maxTokens") ? options.Get("maxTokens").As<Napi::Number>().Int32Value() : 256;
364
+ params.temperature = options.Has("temperature") ? options.Get("temperature").As<Napi::Number>().FloatValue() : 0.7f;
365
+ params.top_p = options.Has("topP") ? options.Get("topP").As<Napi::Number>().FloatValue() : 0.9f;
366
+ params.top_k = options.Has("topK") ? options.Get("topK").As<Napi::Number>().Int32Value() : 40;
402
367
 
403
368
  if (options.Has("stopSequences") && options.Get("stopSequences").IsArray()) {
404
369
  Napi::Array stop_arr = options.Get("stopSequences").As<Napi::Array>();
@@ -413,19 +378,20 @@ Napi::Value Generate(const Napi::CallbackInfo& info) {
413
378
  return env.Undefined();
414
379
  }
415
380
 
416
- Napi::Value GenerateStream(const Napi::CallbackInfo& info) {
381
+ Napi::Value GenerateStream(const Napi::CallbackInfo & info) {
417
382
  Napi::Env env = info.Env();
418
383
 
419
- if (info.Length() < 4 || !info[0].IsNumber() || !info[1].IsObject() ||
420
- !info[2].IsFunction() || !info[3].IsFunction()) {
421
- Napi::TypeError::New(env, "Expected (handle, options, tokenCallback, doneCallback)").ThrowAsJavaScriptException();
384
+ if (info.Length() < 4 || !info[0].IsNumber() || !info[1].IsObject() || !info[2].IsFunction() ||
385
+ !info[3].IsFunction()) {
386
+ Napi::TypeError::New(env, "Expected (handle, options, tokenCallback, doneCallback)")
387
+ .ThrowAsJavaScriptException();
422
388
  return env.Null();
423
389
  }
424
390
 
425
- int handle = info[0].As<Napi::Number>().Int32Value();
426
- Napi::Object options = info[1].As<Napi::Object>();
391
+ int handle = info[0].As<Napi::Number>().Int32Value();
392
+ Napi::Object options = info[1].As<Napi::Object>();
427
393
  Napi::Function token_callback = info[2].As<Napi::Function>();
428
- Napi::Function done_callback = info[3].As<Napi::Function>();
394
+ Napi::Function done_callback = info[3].As<Napi::Function>();
429
395
 
430
396
  // Parse messages array
431
397
  if (!options.Has("messages") || !options.Get("messages").IsArray()) {
@@ -435,14 +401,10 @@ Napi::Value GenerateStream(const Napi::CallbackInfo& info) {
435
401
  std::vector<llama_wrapper::ChatMessage> messages = ParseMessages(options.Get("messages").As<Napi::Array>());
436
402
 
437
403
  llama_wrapper::GenerationParams params;
438
- params.max_tokens = options.Has("maxTokens") ?
439
- options.Get("maxTokens").As<Napi::Number>().Int32Value() : 256;
440
- params.temperature = options.Has("temperature") ?
441
- options.Get("temperature").As<Napi::Number>().FloatValue() : 0.7f;
442
- params.top_p = options.Has("topP") ?
443
- options.Get("topP").As<Napi::Number>().FloatValue() : 0.9f;
444
- params.top_k = options.Has("topK") ?
445
- options.Get("topK").As<Napi::Number>().Int32Value() : 40;
404
+ params.max_tokens = options.Has("maxTokens") ? options.Get("maxTokens").As<Napi::Number>().Int32Value() : 256;
405
+ params.temperature = options.Has("temperature") ? options.Get("temperature").As<Napi::Number>().FloatValue() : 0.7f;
406
+ params.top_p = options.Has("topP") ? options.Get("topP").As<Napi::Number>().FloatValue() : 0.9f;
407
+ params.top_k = options.Has("topK") ? options.Get("topK").As<Napi::Number>().Int32Value() : 40;
446
408
 
447
409
  if (options.Has("stopSequences") && options.Get("stopSequences").IsArray()) {
448
410
  Napi::Array stop_arr = options.Get("stopSequences").As<Napi::Array>();
@@ -457,7 +419,7 @@ Napi::Value GenerateStream(const Napi::CallbackInfo& info) {
457
419
  return env.Undefined();
458
420
  }
459
421
 
460
- Napi::Value IsModelLoaded(const Napi::CallbackInfo& info) {
422
+ Napi::Value IsModelLoaded(const Napi::CallbackInfo & info) {
461
423
  Napi::Env env = info.Env();
462
424
 
463
425
  if (info.Length() < 1 || !info[0].IsNumber()) {
@@ -468,13 +430,13 @@ Napi::Value IsModelLoaded(const Napi::CallbackInfo& info) {
468
430
  int handle = info[0].As<Napi::Number>().Int32Value();
469
431
 
470
432
  std::lock_guard<std::mutex> lock(g_models_mutex);
471
- auto it = g_models.find(handle);
472
- bool loaded = it != g_models.end() && it->second->is_loaded();
433
+ auto it = g_models.find(handle);
434
+ bool loaded = it != g_models.end() && it->second->is_loaded();
473
435
 
474
436
  return Napi::Boolean::New(env, loaded);
475
437
  }
476
438
 
477
- Napi::Value Embed(const Napi::CallbackInfo& info) {
439
+ Napi::Value Embed(const Napi::CallbackInfo & info) {
478
440
  Napi::Env env = info.Env();
479
441
 
480
442
  if (info.Length() < 3 || !info[0].IsNumber() || !info[1].IsObject() || !info[2].IsFunction()) {
@@ -482,8 +444,8 @@ Napi::Value Embed(const Napi::CallbackInfo& info) {
482
444
  return env.Null();
483
445
  }
484
446
 
485
- int handle = info[0].As<Napi::Number>().Int32Value();
486
- Napi::Object options = info[1].As<Napi::Object>();
447
+ int handle = info[0].As<Napi::Number>().Int32Value();
448
+ Napi::Object options = info[1].As<Napi::Object>();
487
449
  Napi::Function callback = info[2].As<Napi::Function>();
488
450
 
489
451
  // Parse texts array
@@ -492,7 +454,7 @@ Napi::Value Embed(const Napi::CallbackInfo& info) {
492
454
  return env.Null();
493
455
  }
494
456
 
495
- Napi::Array texts_arr = options.Get("texts").As<Napi::Array>();
457
+ Napi::Array texts_arr = options.Get("texts").As<Napi::Array>();
496
458
  std::vector<std::string> texts;
497
459
  for (uint32_t i = 0; i < texts_arr.Length(); i++) {
498
460
  texts.push_back(texts_arr.Get(i).As<Napi::String>().Utf8Value());
@@ -504,6 +466,48 @@ Napi::Value Embed(const Napi::CallbackInfo& info) {
504
466
  return env.Undefined();
505
467
  }
506
468
 
469
+ Napi::Value Tokenize(const Napi::CallbackInfo & info) {
470
+ Napi::Env env = info.Env();
471
+
472
+ if (info.Length() < 2 || !info[0].IsNumber() || !info[1].IsObject()) {
473
+ Napi::TypeError::New(env, "Expected (handle, options)").ThrowAsJavaScriptException();
474
+ return env.Null();
475
+ }
476
+
477
+ int handle = info[0].As<Napi::Number>().Int32Value();
478
+ Napi::Object options = info[1].As<Napi::Object>();
479
+
480
+ if (!options.Has("text") || !options.Get("text").IsString()) {
481
+ Napi::TypeError::New(env, "Expected text string in options").ThrowAsJavaScriptException();
482
+ return env.Null();
483
+ }
484
+
485
+ std::string text = options.Get("text").As<Napi::String>().Utf8Value();
486
+ bool add_bos = options.Has("addBos") ? options.Get("addBos").As<Napi::Boolean>().Value() : true;
487
+
488
+ llama_wrapper::LlamaModel * model = nullptr;
489
+
490
+ {
491
+ std::lock_guard<std::mutex> lock(g_models_mutex);
492
+ auto it = g_models.find(handle);
493
+ if (it == g_models.end()) {
494
+ Napi::Error::New(env, "Invalid model handle").ThrowAsJavaScriptException();
495
+ return env.Null();
496
+ }
497
+ model = it->second.get();
498
+ }
499
+
500
+ std::vector<int32_t> tokens = model->tokenize(text, add_bos);
501
+
502
+ // Create Int32Array result
503
+ Napi::Int32Array result = Napi::Int32Array::New(env, tokens.size());
504
+ for (size_t i = 0; i < tokens.size(); i++) {
505
+ result[i] = tokens[i];
506
+ }
507
+
508
+ return result;
509
+ }
510
+
507
511
  // ============================================================================
508
512
  // Module Initialization
509
513
  // ============================================================================
@@ -515,8 +519,8 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
515
519
  exports.Set("generateStream", Napi::Function::New(env, GenerateStream));
516
520
  exports.Set("isModelLoaded", Napi::Function::New(env, IsModelLoaded));
517
521
  exports.Set("embed", Napi::Function::New(env, Embed));
522
+ exports.Set("tokenize", Napi::Function::New(env, Tokenize));
518
523
  return exports;
519
524
  }
520
525
 
521
526
  NODE_API_MODULE(llama_binding, Init)
522
-